
Conversation

@kylesayrs (Contributor) commented Feb 7, 2025

Purpose

Related issues

Background

Many issues arise when attempting to save and load Pixtral configs, namely that the pixtral model is saved with text_config.head_dim=128, but upon loading receives the value text_config.head_dim=160.

Through further inspection, it was found that this is due to how head_dim is derived from the default values of each config:

MistralConfig

hidden_size=4096,
num_attention_heads=32,
head_dim = head_dim or hidden_size // num_attention_heads = 128

Pixtral 12B

hidden_size=5120,
head_dim=128,
num_attention_heads = None

When transformers configs are saved, a diff between the config being saved and the default config is computed in order to reduce verbosity. When the Pixtral 12B config saves the head_dim attribute, it checks against the MistralConfig default and finds that both values are the same (128), and therefore does not save head_dim to disk.

However, when loading the same config using the MistralConfig class, the class recomputes head_dim using the saved hidden_size=5120 (which differs from the default 4096), and returns the value head_dim=160. This incorrect calculation causes the model to be loaded incorrectly.

Note that for the pixtral model, head_dim != hidden_size // num_attention_heads, so this issue cannot be remedied by populating the num_attention_heads attribute of the pixtral config.
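
To make the failure mode concrete, here is a minimal sketch of the recomputation described above, using only the default values listed earlier (a toy reproduction, not the actual transformers implementation):

# Toy reproduction of the save/load mismatch (values from the defaults above).
mistral_default_hidden_size = 4096
mistral_default_num_attention_heads = 32
mistral_default_head_dim = mistral_default_hidden_size // mistral_default_num_attention_heads  # 128

pixtral_hidden_size = 5120
pixtral_head_dim = 128

# On save: head_dim equals the Mistral default (128), so the diff drops it.
assert pixtral_head_dim == mistral_default_head_dim

# On load: head_dim is missing, so MistralConfig recomputes it from the saved
# hidden_size, producing the wrong value.
recomputed_head_dim = pixtral_hidden_size // mistral_default_num_attention_heads
assert recomputed_head_dim == 160  # should have been 128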

While I consider creating a dedicated PixtralTextConfig to be the most "correct" solution, one that reflects the differences in assumptions made by the pixtral and mistral models, implementing a config class in transformers which does not have a unique model associated with it requires additional work and testing. This solution is backwards compatible with existing pixtral configs and does not require intensive changes to configs.

Changes

  • Set is_composition=True for LlavaConfig, which forces all subconfig attributes to be saved to disk.
    • This means that head_dim is loaded from disk and overrides the non-applicable calculation in MistralConfig (see the sketch after this list)
  • Add tests for reloading llava configs
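
For context, a hedged toy model of why is_composition=True has this effect (it paraphrases the to_diff_dict logic quoted later in this thread; it is not the transformers source): with the flag set, the diff baseline is empty, so no subconfig attribute is ever dropped.

# Toy model of the diff-against-defaults behaviour (not the transformers source).
def to_diff_dict(config_dict, default_dict, is_composition):
    baseline = {} if is_composition else default_dict
    return {k: v for k, v in config_dict.items() if k not in baseline or baseline[k] != v}

pixtral_text = {"hidden_size": 5120, "head_dim": 128}
mistral_defaults = {"hidden_size": 4096, "head_dim": 128}

saved = to_diff_dict(pixtral_text, mistral_defaults, is_composition=False)
assert "head_dim" not in saved   # equal to the default, so it is dropped

saved = to_diff_dict(pixtral_text, mistral_defaults, is_composition=True)
assert saved["head_dim"] == 128  # empty baseline keeps everything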

Potential Follow-ups

Testing

from transformers import LlavaConfig

config = LlavaConfig.from_pretrained("mistral-community/pixtral-12b")
assert config.text_config.head_dim == 128

config.save_pretrained("tmp")
config = LlavaConfig.from_pretrained("tmp")
assert config.text_config.head_dim == 128  # previously, this would be overwritten by MistralConfig to be 160

@zucchini-nlp (Member) left a comment

Hey, thanks for finding the root cause. However, I don't think we can solve this by just setting the flag. This flag is meant for somewhat different purposes, and the naming is misleading.

I think we should find why the saved/loaded config inferred an incorrect LM-config type and fix it in a way that lets us infer the correct lm-config. And a small test would be nice to show that composite configs can be saved/loaded with different sub-components even when attributes share default values

@kylesayrs (Contributor, author) commented Feb 7, 2025

Hi @zucchini-nlp!

I think we should find why the saved/loaded config inferred an incorrect LM-config type and fix it in a way that lets us infer the correct lm-config

This would likely involve having to modify the PretrainedConfig base class to initialize using the LlavaConfig scheme. Specifically,

class_config_dict = self.__class__().to_dict() if not self.is_composition else {}

This line calls the init with no arguments, which causes LlavaConfig to default to a Llama text config:

if isinstance(text_config, dict):
    text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
    text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
    text_config = CONFIG_MAPPING["llama"]()

This is an inevitable problem because LlavaConfig must have some default if no arguments are provided. Changing this default to MistralConfig will break all configs which default to LlamaConfig.
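
A quick illustration of the default behaviour described above (assuming transformers is installed and the hub checkpoint is reachable; the printed values are what the quoted fallback and this PR's discussion imply, not verified here):

# Hedged illustration: the no-argument default vs. the Pixtral checkpoint.
from transformers import LlavaConfig

default_config = LlavaConfig()
print(default_config.text_config.model_type)   # "llama" (the fallback quoted above)

pixtral_config = LlavaConfig.from_pretrained("mistral-community/pixtral-12b")
print(pixtral_config.text_config.model_type)   # "mistral", per the discussion in this PR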

Your comment suggests modifying PretrainedConfig.to_diff_dict to specialize with init args. Maybe something like

init_args = {key: getattr(value, "model_type", None) for key, value in self.sub_configs.items()}
class_config_dict = self.__class__(**init_args).to_dict() if not self.is_composition else {}

The only problem with this implementation is that it forces all configs to follow the particular scheme of using the sub_configs attribute and specializing via the model_type subattribute and nothing else. This would mean having to fix configs which do not follow this scheme, like HybridClipConfig, and likely a large suite of tests across all configs.

@kylesayrs (Contributor, author) commented Feb 7, 2025

Hey, thanks for finding the root cause. However, I don't think we can solve this by just setting the flag. This flag is meant for somewhat different purposes, and the naming is misleading.

@zucchini-nlp Could you provide a little more context as to how this flag is used? Grepping the transformers codebase reveals that this attribute is only consumed by PretrainedConfig.to_diff_dict and is set to True for many nested configs similar to LlavaConfig.
https://github.com/search?q=repo%3Ahuggingface%2Ftransformers%20is_composition&type=code

There may be some upstream use that I am unaware of.

@zucchini-nlp (Member) commented

Thanks, now I see why the model types are not matching. The flag was introduced at first for models like Musicgen, where the sub-configs had no default values set and had to be indicated by users explicitly (in #25237).

def __init__(self, **kwargs):
    super().__init__(**kwargs)
    if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
        raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")

For Llava we didn't set it, because we have default values. But now, since the same config class is used with various LM backbones, the defaults don't make much sense. I agree with the proposed fix and imo we'll need the flag to be added in all configs that accept any model_type for the sub-config.

Can you propagate it to all VLMs like Llava and add a small test?

@kylesayrs (Contributor, author) commented Feb 9, 2025

WIP comment: The cause of this issue is more subtle than I initially thought. There already exist mechanisms for catching diffs of non-default subconfigs that I initially missed, but they seem to not work well with computed config values.

diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))

elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]:

Mistral has default values

hidden_size=4096,
num_attention_heads=32,
self.head_dim = head_dim or hidden_size // num_attention_heads = 128

Pixtral has default values

hidden_size=5120,
head_dim=128,
num_attention_heads = None

Because the default head dim values match, the diff does not register and the head_dim is not written. But upon reloading, head_dim is recomputed using the new hidden_size, which results in the 160 value.

This may be fixable by simply adding num_attention_heads=40 to the pixtral config

@zucchini-nlp (Member) commented Feb 10, 2025

@kylesayrs yeah, that's what I thought at first, but the Pixtral model checkpoint actually has num_attention_heads=32, so it wouldn't work in that case. The reason is that for Pixtral, head_dim * num_attention_heads != hidden_size.
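
For concreteness, plugging in the checkpoint values quoted above:

# Pixtral text config: head_dim * num_attention_heads does not equal hidden_size.
assert 128 * 32 == 4096   # head_dim * num_attention_heads
assert 4096 != 5120       # ...which differs from hidden_size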

Signed-off-by: Kyle Sayers <[email protected]>
@kylesayrs (Contributor, author) commented Feb 10, 2025

@zucchini-nlp I've updated with a WIP of what a pixtral text config might look like. Let me know what you think. If this approach isn't viable, we can fall back to is_composition

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
@zucchini-nlp (Member) left a comment

This looks good to me as a workaround for the Pixtral case. But I think the is_composition flag will be a more generalizable solution, given that we cannot create sub-configs for every vision LM that clashes with the default llama

Making LlavaConfig work with any backbone seems better to me

@kylesayrs (Contributor, author) commented Feb 11, 2025

@zucchini-nlp I think the is_composition solution works, but it creates very verbose configs which may be unnecessary. To me, the problem is that MistralConfig does not apply to the pixtral model, as the Mistral assumption that head_dim = hidden_size // num_attention_heads is broken.

I still think that using is_composition is a good solution for now, as it ensures backwards compatibility with existing pixtral configs and doesn't involve creating a new "pixtral_text" config which does not correspond to a unique model (which afaict requires extra work to support with transformers)

@kylesayrs (Contributor, author) commented

@zucchini-nlp I've updated the PR with a clearer explanation of the issue and links to potential follow ups.

@zucchini-nlp (Member) left a comment

LGTM, thanks!

Indeed the flag here is not the best way, and we need to consider refactoring how nested configs are saved/loaded in subsequent PRs. But this will work as a solution for Pixtral

"""
vision_config = {
"model_type": "pixtral",
"head_dim": 64,
Member:
do we need this to be 128 to trigger the same error as currently?

Contributor Author:

It's actually the head_dim on the text_config that triggers the issue

Member:

oh right, my bad, didn't notice this was vision

@zucchini-nlp (Member) commented

cc @ArthurZucker if you want to give a look

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp (Member) commented

Will merge it, since it's not core and nice to have in the next release

@zucchini-nlp zucchini-nlp merged commit bcfc9d7 into huggingface:main Feb 14, 2025
11 checks passed
@kylesayrs kylesayrs deleted the redhat/fix-llava-config branch February 15, 2025 17:21
ydshieh added a commit that referenced this pull request Feb 17, 2025
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Feb 21, 2025
* add is_composition flag to LlavaConfig

Signed-off-by: Kyle Sayers <[email protected]>

* WIP: pixtral text config

Signed-off-by: Kyle Sayers <[email protected]>

* fix style

Signed-off-by: Kyle Sayers <[email protected]>

* add test

Signed-off-by: Kyle Sayers <[email protected]>

* use is_composition for pixtral

Signed-off-by: Kyle Sayers <[email protected]>

* Revert "use is_composition for pixtral"

This reverts commit a53d5f9.

* Revert "Revert "use is_composition for pixtral""

This reverts commit 3ab1c99.

---------

Signed-off-by: Kyle Sayers <[email protected]>
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Feb 21, 2025