Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/source/en/quantization/torchao.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@ input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>

</hfoption>
<hfoption id="int4-weight-only">

Expand Down Expand Up @@ -332,6 +330,7 @@ quantized_model.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128", safe_serializatio
tokenizer.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128")
```
</hfoption>
</hfoptions>


## Loading quantized models
Expand Down
15 changes: 15 additions & 0 deletions src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,5 +1045,20 @@ def __init__(

super().__init__(**kwargs)

def get_text_config(self, decoder=False) -> "PretrainedConfig":
    """
    Returns the config that is meant to be used with text IO. On most models, it is the original config instance
    itself. On specific composite models, it is under a set of valid names.

    Args:
        decoder (`bool`, *optional*, defaults to `False`):
            If set to `True`, then only search for decoder config names.
    """
    # NOTE: must be an instance method, not a `@classmethod` — `thinker_config` is an
    # instance attribute, so a classmethod (where `self` is bound to the class) would
    # fail with AttributeError when called on a config instance.
    # Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
    # except for Qwen yet. This has to be generalized if more deeply nested configs are
    # added. NOTE: currently method used only by vLLM
    # Forward `decoder` so callers asking specifically for the decoder config are honored.
    return self.thinker_config.get_text_config(decoder=decoder)


__all__ = ["Qwen2_5OmniConfig", "Qwen2_5OmniThinkerConfig", "Qwen2_5OmniTalkerConfig", "Qwen2_5OmniToken2WavConfig"]
Original file line number Diff line number Diff line change
Expand Up @@ -2503,7 +2503,9 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
)

if not return_dict:
output = (logits,) + outputs
Expand Down Expand Up @@ -4384,6 +4386,7 @@ def __init__(self, config):
self.speaker_map = {}
if config.enable_audio_output:
self.enable_talker()
self.post_init()

def enable_talker(self):
self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)
Expand Down
20 changes: 19 additions & 1 deletion src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,21 @@ def __init__(

super().__init__(**kwargs)

def get_text_config(self, decoder=False) -> "PretrainedConfig":
    """
    Returns the config that is meant to be used with text IO. On most models, it is the original config instance
    itself. On specific composite models, it is under a set of valid names.

    Args:
        decoder (`bool`, *optional*, defaults to `False`):
            If set to `True`, then only search for decoder config names.
    """
    # NOTE: must be an instance method, not a `@classmethod` — `thinker_config` is an
    # instance attribute, so a classmethod (where `self` is bound to the class) would
    # fail with AttributeError when called on a config instance.
    # Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
    # except for Qwen yet. This has to be generalized if more deeply nested configs are
    # added. NOTE: currently method used only by vLLM
    # Forward `decoder` so callers asking specifically for the decoder config are honored.
    return self.thinker_config.get_text_config(decoder=decoder)


class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
config_class = Qwen2_5OmniConfig
Expand Down Expand Up @@ -2463,7 +2478,9 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
)

if not return_dict:
output = (logits,) + outputs
Expand Down Expand Up @@ -4053,6 +4070,7 @@ def __init__(self, config):
self.speaker_map = {}
if config.enable_audio_output:
self.enable_talker()
self.post_init()

def enable_talker(self):
self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)
Expand Down