|  | 
|  | 1 | +import tempfile | 
|  | 2 | +import unittest | 
|  | 3 | + | 
|  | 4 | +from transformers import LlavaConfig | 
|  | 5 | + | 
|  | 6 | + | 
class LlavaConfigTest(unittest.TestCase):
    """Round-trip serialization tests for ``LlavaConfig``.

    Each test builds a config, saves it with ``save_pretrained`` and reloads
    it with ``from_pretrained``, asserting the reloaded config is
    dict-identical to the original.
    """

    def _assert_config_round_trips(self, config):
        # Shared save -> reload -> compare step used by every test below.
        with tempfile.TemporaryDirectory() as tmp_dir:
            config.save_pretrained(tmp_dir)
            reloaded = LlavaConfig.from_pretrained(tmp_dir)
            # assertEqual (rather than a bare `assert`) survives `python -O`
            # and prints a readable dict diff when the round trip drifts.
            self.assertEqual(config.to_dict(), reloaded.to_dict())

    def test_llava_reload(self):
        """
        Simple test for reloading default llava configs
        """
        self._assert_config_round_trips(LlavaConfig())

    def test_pixtral_reload(self):
        """
        Simple test for reloading pixtral configs
        """
        vision_config = {
            "model_type": "pixtral",
            "head_dim": 64,
            "hidden_act": "silu",
            "image_size": 1024,
            "is_composition": True,
            "patch_size": 16,
            "rope_theta": 10000.0,
            "tie_word_embeddings": False,
        }

        text_config = {
            "model_type": "mistral",
            "hidden_size": 5120,
            "head_dim": 128,
            "num_attention_heads": 32,
            "intermediate_size": 14336,
            "is_composition": True,
            "max_position_embeddings": 1024000,
            "num_hidden_layers": 40,
            "num_key_value_heads": 8,
            "rms_norm_eps": 1e-05,
            "rope_theta": 1000000000.0,
            "sliding_window": None,
            "vocab_size": 131072,
        }

        self._assert_config_round_trips(
            LlavaConfig(vision_config=vision_config, text_config=text_config)
        )

    def test_arbitrary_reload(self):
        """
        Simple test for reloading arbitrarily composed subconfigs
        """
        # Start from the default config dict and swap in unrelated
        # vision/text model types to exercise arbitrary composition.
        default_values = LlavaConfig().to_dict()
        default_values["vision_config"]["model_type"] = "qwen2_vl"
        default_values["text_config"]["model_type"] = "opt"

        self._assert_config_round_trips(LlavaConfig(**default_values))
0 commit comments