@@ -1747,12 +1747,13 @@ def test_resize_position_vector_embeddings(self):
             self.assertTrue(models_equal)

     def test_resize_tokens_embeddings(self):
+        if not self.test_resize_embeddings:
+            self.skipTest(reason="test_resize_embeddings is set to `False`")
+
         (
             original_config,
             inputs_dict,
         ) = self.model_tester.prepare_config_and_inputs_for_common()
-        if not self.test_resize_embeddings:
-            self.skipTest(reason="test_resize_embeddings is set to `False`")

         for model_class in self.all_model_classes:
             config = copy.deepcopy(original_config)
@@ -1764,18 +1765,15 @@ def test_resize_tokens_embeddings(self):
             if self.model_tester.is_training is False:
                 model.eval()

-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
             # Retrieve the embeddings and clone theme
             model_embed = model.resize_token_embeddings(model_vocab_size)
             cloned_embeddings = model_embed.weight.clone()

             # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
             model_embed = model.resize_token_embeddings(model_vocab_size + 10)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
+
             self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
             # Check that it actually resizes the embeddings matrix
             self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
@@ -1787,11 +1785,7 @@ def test_resize_tokens_embeddings(self):

             # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
             model_embed = model.resize_token_embeddings(model_vocab_size - 15)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
             # Check that it actually resizes the embeddings matrix
             self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
@@ -1817,21 +1811,13 @@ def test_resize_tokens_embeddings(self):
             model = model_class(config)
             model.to(torch_device)

-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
             model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertTrue(new_model_vocab_size + 10, model_vocab_size)

             model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertTrue(model_embed.weight.shape[0] // 64, 0)

             self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
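For context on the `pad_to_multiple_of` assertions in the hunk above: the resized embedding matrix is expected to be rounded up to the nearest multiple of the given value. A minimal sketch of that rounding arithmetic, assuming the ceiling-style padding described for `resize_token_embeddings`; the helper name below is illustrative and not part of this file:

```python
def padded_vocab_size(new_num_tokens: int, pad_to_multiple_of: int) -> int:
    # Round the requested vocab size up to the nearest multiple,
    # e.g. 99 -> 128 when padding to a multiple of 64.
    return ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of


assert padded_vocab_size(99, 64) == 128
assert padded_vocab_size(128, 64) == 128
assert padded_vocab_size(100, 1) == 100  # pad_to_multiple_of=1 leaves the size unchanged
```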
@@ -1852,13 +1838,10 @@ def test_resize_tokens_embeddings(self):
                 model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)

     def test_resize_embeddings_untied(self):
-        (
-            original_config,
-            inputs_dict,
-        ) = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.test_resize_embeddings:
             self.skipTest(reason="test_resize_embeddings is set to `False`")

+        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         original_config.tie_word_embeddings = False

         # if model cannot untied embeddings -> leave test
@@ -1874,13 +1857,9 @@ def test_resize_embeddings_untied(self):
                 continue

             # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
             model.resize_token_embeddings(model_vocab_size + 10)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
             output_embeds = model.get_output_embeddings()
             self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
@@ -1892,11 +1871,7 @@ def test_resize_embeddings_untied(self):

             # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
             model.resize_token_embeddings(model_vocab_size - 15)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
             # Check that it actually resizes the embeddings matrix
             output_embeds = model.get_output_embeddings()
@@ -1988,7 +1963,7 @@ def check_same_values(layer_1, layer_2):
             # self.assertTrue(check_same_values(embeddings, decoding))

             # Check that after resize they remain tied.
-            vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            vocab_size = config.get_text_config().vocab_size
             model_tied.resize_token_embeddings(vocab_size + 10)
             params_tied_2 = list(model_tied.parameters())
             self.assertEqual(len(params_tied_2), len(params_tied))
@@ -4831,7 +4806,7 @@ def test_forward_with_num_logits_to_keep(self):

             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
             batch_size, sequence_length = inputs["input_ids"].shape
-            vocab_size = config.vocab_size
+            vocab_size = config.get_text_config().vocab_size
             model = model_class(config).to(device=torch_device).eval()

             # num_logits_to_keep=0 is a special case meaning "keep all logits"
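The recurring change throughout this diff replaces the `hasattr(config, "text_config")` ternary with a single `config.get_text_config().vocab_size` lookup. As a sketch, the helper below restates the removed ternary in one place, which is presumably what the `get_text_config()` accessor centralizes; the function name and the stand-in config objects are illustrative only:

```python
from types import SimpleNamespace


def vocab_size_of(config) -> int:
    # Old-style lookup that the diff removes: composite configs (e.g. vision-language
    # models) nest the language-model settings under `text_config`, while plain text
    # configs carry `vocab_size` directly.
    text_config = config.text_config if hasattr(config, "text_config") else config
    return text_config.vocab_size


# Usage sketch with stand-in config objects (illustrative only):
text_only = SimpleNamespace(vocab_size=32000)
composite = SimpleNamespace(text_config=SimpleNamespace(vocab_size=32000))
assert vocab_size_of(text_only) == vocab_size_of(composite) == 32000
```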