Config: unified logic to retrieve text config #33219
Changes from all commits
```diff
@@ -67,4 +67,4 @@ def main():

 if __name__ == "__main__":
-    main()
+    main()
```
In `src/transformers/configuration_utils.py`:

```diff
@@ -1019,17 +1019,17 @@ def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
         """
         non_default_generation_parameters = {}
         decoder_attribute_name = None
-        default_config = None

         # Composite models don't have a default config, use their decoder config as a fallback for default values
         # If no known pattern is matched, then `default_config = None` -> check against the global generation defaults
         try:
             default_config = self.__class__()
         except ValueError:
-            for decoder_attribute_name in ("decoder", "generator", "text_config"):
-                if hasattr(self, decoder_attribute_name):
-                    default_config = getattr(self, decoder_attribute_name).__class__()
-                    break
+            decoder_config = self.get_text_config(decoder=True)
+            if decoder_config is not self:
+                default_config = decoder_config.__class__()
+            else:
+                default_config = None

         # If it is a composite model, we want to check the subconfig that will be used for generation
         self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)
```
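For context, a minimal sketch of the fallback this hunk relies on, assuming a transformers version that already includes this PR: a composite config such as `EncoderDecoderConfig` cannot be instantiated without its sub-configs (its constructor raises `ValueError`), so `get_text_config(decoder=True)` hands back the `decoder` sub-config, whose class then supplies the default values.

```python
# Minimal sketch of the fallback above; assumes a transformers version that
# includes this PR. EncoderDecoderConfig is a composite config, so generation
# defaults come from its `decoder` sub-config.
from transformers import BertConfig, EncoderDecoderConfig

config = EncoderDecoderConfig.from_encoder_decoder_configs(
    encoder_config=BertConfig(), decoder_config=BertConfig()
)

decoder_config = config.get_text_config(decoder=True)
assert decoder_config is config.decoder      # the "decoder" attribute is matched
assert decoder_config is not config          # composite model -> use the sub-config
default_config = decoder_config.__class__()  # fresh BertConfig() with library defaults
```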
The new `get_text_config` method added in the same file:

```diff
@@ -1057,6 +1057,36 @@ def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
         return non_default_generation_parameters

+    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+        """
+        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+        itself. On specific composite models, it is under a set of valid names.
+
+        If `decoder` is set to `True`, then only search for decoder config names.
+        """
+        decoder_possible_text_config_names = ("decoder", "generator", "text_config")
+        encoder_possible_text_config_names = ("text_encoder",)
+        if decoder:
+            possible_text_config_names = decoder_possible_text_config_names
+        else:
+            possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
+
+        valid_text_config_names = []
+        for text_config_name in possible_text_config_names:
+            if hasattr(self, text_config_name):
+                text_config = getattr(self, text_config_name, None)
+                if text_config is not None:
+                    valid_text_config_names += [text_config_name]
+
+        if len(valid_text_config_names) > 1:
+            raise ValueError(
+                f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
+                "case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly."
+            )
+        elif len(valid_text_config_names) == 1:
+            return getattr(self, valid_text_config_names[0])
+        return self
+

 def get_configuration_file(configuration_files: List[str]) -> str:
     """
```

Review discussion on the `decoder` docstring line:

- Reviewer: for my own understanding, what does it mean to search in "decoder config names"? Is it somehow related to a model being an encoder-decoder or decoder-only? From what I see, […]
- Author: It is indeed mostly for […]
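To illustrate how the new helper is meant to be used (a sketch, assuming a transformers release that ships `PretrainedConfig.get_text_config`): plain text configs return themselves, while composite configs such as LLaVA return their text sub-config.

```python
# Sketch of the new API; assumes a transformers version with this PR merged.
from transformers import GPT2Config, LlavaConfig

plain = GPT2Config()
composite = LlavaConfig()  # builds default vision_config + text_config

# On a plain text model config, the helper returns the instance itself.
assert plain.get_text_config() is plain

# On a composite model config, it returns the text sub-config.
text_cfg = composite.get_text_config(decoder=True)
assert text_cfg is composite.text_config
print(type(text_cfg).__name__, text_cfg.vocab_size)
```

If more than one candidate name is present on a config (e.g. both `text_encoder` and `text_config`), the helper raises instead of guessing, per the `ValueError` branch above.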
In `tests/test_modeling_common.py`:

```diff
@@ -1747,12 +1747,13 @@ def test_resize_position_vector_embeddings(self):
         self.assertTrue(models_equal)

     def test_resize_tokens_embeddings(self):
+        if not self.test_resize_embeddings:
+            self.skipTest(reason="test_resize_embeddings is set to `False`")
+
         (
             original_config,
             inputs_dict,
         ) = self.model_tester.prepare_config_and_inputs_for_common()
-        if not self.test_resize_embeddings:
-            self.skipTest(reason="test_resize_embeddings is set to `False`")

         for model_class in self.all_model_classes:
             config = copy.deepcopy(original_config)
```

Comment on lines +1750 to +1751: (moved the skip here: no point in spending compute if we are going to skip the test)

```diff
@@ -1764,18 +1765,15 @@ def test_resize_tokens_embeddings(self):
             if self.model_tester.is_training is False:
                 model.eval()

-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
             # Retrieve the embeddings and clone theme
             model_embed = model.resize_token_embeddings(model_vocab_size)
             cloned_embeddings = model_embed.weight.clone()

             # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
             model_embed = model.resize_token_embeddings(model_vocab_size + 10)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
             # Check that it actually resizes the embeddings matrix
             self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
@@ -1787,11 +1785,7 @@ def test_resize_tokens_embeddings(self):

             # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
             model_embed = model.resize_token_embeddings(model_vocab_size - 15)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
             # Check that it actually resizes the embeddings matrix
             self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
@@ -1817,21 +1811,13 @@ def test_resize_tokens_embeddings(self):
                 model = model_class(config)
                 model.to(torch_device)

-                model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+                model_vocab_size = config.get_text_config().vocab_size
                 model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
-                new_model_vocab_size = (
-                    model.config.text_config.vocab_size
-                    if hasattr(model.config, "text_config")
-                    else model.config.vocab_size
-                )
+                new_model_vocab_size = model.config.get_text_config().vocab_size
                 self.assertTrue(new_model_vocab_size + 10, model_vocab_size)

                 model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
-                new_model_vocab_size = (
-                    model.config.text_config.vocab_size
-                    if hasattr(model.config, "text_config")
-                    else model.config.vocab_size
-                )
+                new_model_vocab_size = model.config.get_text_config().vocab_size
                 self.assertTrue(model_embed.weight.shape[0] // 64, 0)

                 self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
@@ -1852,13 +1838,10 @@ def test_resize_tokens_embeddings(self):
                     model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)

     def test_resize_embeddings_untied(self):
-        (
-            original_config,
-            inputs_dict,
-        ) = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.test_resize_embeddings:
             self.skipTest(reason="test_resize_embeddings is set to `False`")

+        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         original_config.tie_word_embeddings = False

         # if model cannot untied embeddings -> leave test
@@ -1874,13 +1857,9 @@ def test_resize_embeddings_untied(self):
                 continue

             # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
             model.resize_token_embeddings(model_vocab_size + 10)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
             output_embeds = model.get_output_embeddings()
             self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
@@ -1892,11 +1871,7 @@ def test_resize_embeddings_untied(self):

             # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
             model.resize_token_embeddings(model_vocab_size - 15)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
             # Check that it actually resizes the embeddings matrix
             output_embeds = model.get_output_embeddings()
@@ -1988,7 +1963,7 @@ def check_same_values(layer_1, layer_2):
             # self.assertTrue(check_same_values(embeddings, decoding))

             # Check that after resize they remain tied.
-            vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            vocab_size = config.get_text_config().vocab_size
             model_tied.resize_token_embeddings(vocab_size + 10)
             params_tied_2 = list(model_tied.parameters())
             self.assertEqual(len(params_tied_2), len(params_tied))
@@ -4831,7 +4806,7 @@ def test_forward_with_num_logits_to_keep(self):
             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
             batch_size, sequence_length = inputs["input_ids"].shape
-            vocab_size = config.vocab_size
+            vocab_size = config.get_text_config().vocab_size
             model = model_class(config).to(device=torch_device).eval()

             # num_logits_to_keep=0 is a special case meaning "keep all logits"
```
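The repeated replacement above collapses the old `hasattr(config, "text_config")` ternary into a single call. A runnable sketch of the same resize check outside the test suite, assuming a transformers version that ships `get_text_config` and using a deliberately tiny GPT-2 config:

```python
# Tiny, self-contained version of the resize check above; assumes transformers
# with `PretrainedConfig.get_text_config` and a local PyTorch install.
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(n_layer=1, n_head=2, n_embd=8, vocab_size=100)
model = GPT2LMHeadModel(config)

model_vocab_size = model.config.get_text_config().vocab_size  # plain config -> itself
model.resize_token_embeddings(model_vocab_size + 10)
new_model_vocab_size = model.config.get_text_config().vocab_size

assert new_model_vocab_size == model_vocab_size + 10
assert model.get_input_embeddings().weight.shape[0] == new_model_vocab_size
```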
In the pipeline test utilities (likely `tests/test_pipeline_mixin.py`):

```diff
@@ -675,14 +675,12 @@ def validate_test_components(test_case, task, model, tokenizer, processor):
     # Avoid `IndexError` in embedding layers
     CONFIG_WITHOUT_VOCAB_SIZE = ["CanineConfig"]
     if tokenizer is not None:
-        config_vocab_size = getattr(model.config, "vocab_size", None)
+        # Removing `decoder=True` in `get_text_config` can lead to conflicting values e.g. in MusicGen
+        config_vocab_size = getattr(model.config.get_text_config(decoder=True), "vocab_size", None)
         # For CLIP-like models
         if config_vocab_size is None:
-            if hasattr(model.config, "text_config"):
-                config_vocab_size = getattr(model.config.text_config, "vocab_size", None)
-            elif hasattr(model.config, "text_encoder"):
+            if hasattr(model.config, "text_encoder"):
                 config_vocab_size = getattr(model.config.text_encoder, "vocab_size", None)

         if config_vocab_size is None and model.config.__class__.__name__ not in CONFIG_WITHOUT_VOCAB_SIZE:
             raise ValueError(
                 "Could not determine `vocab_size` from model configuration while `tokenizer` is not `None`."
```

Review comment on the `decoder=True` comment line: Alternatively, we can add a flag to a) return the first valid match; OR b) return all matches when there is more than one match.
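The reviewer's suggestion above is not implemented in this PR; a hypothetical sketch of option (b), returning every match instead of raising on ambiguity, could look roughly like this (the helper name and behavior are assumptions, not library API):

```python
# Hypothetical helper sketching the reviewer's option (b); NOT part of
# transformers. It collects every matching sub-config instead of raising
# when the lookup is ambiguous.
from typing import List

from transformers import PretrainedConfig


def get_all_text_configs(config: PretrainedConfig, decoder: bool = False) -> List[PretrainedConfig]:
    names = ("decoder", "generator", "text_config")
    if not decoder:
        names = ("text_encoder",) + names
    matches = [getattr(config, name) for name in names if getattr(config, name, None) is not None]
    # Fall back to the config itself, mirroring `get_text_config()`.
    return matches or [config]
```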