Merged
1 change: 0 additions & 1 deletion src/transformers/configuration_utils.py
@@ -890,7 +890,6 @@ def to_diff_dict(self) -> dict[str, Any]:
isinstance(getattr(self, key, None), PreTrainedConfig)
and key in class_config_dict
and isinstance(class_config_dict[key], dict)
- or key in self.sub_configs
):
# For nested configs we need to clean the diff recursively
diff = recursive_diff_dict(value, default_config_dict, config_obj=getattr(self, key, None))
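For context, `to_diff_dict` keeps only the attributes that differ from the defaults and recurses into nested sub-configs. The sketch below is a minimal, standalone illustration of that recursive-diff idea; the helper name `recursive_diff` and the plain-dict configs are assumptions for the example, not the library's actual `recursive_diff_dict` implementation:

```python
from typing import Any


def recursive_diff(current: dict[str, Any], default: dict[str, Any]) -> dict[str, Any]:
    """Return only the keys whose values differ from the defaults, recursing into nested dicts."""
    diff: dict[str, Any] = {}
    for key, value in current.items():
        if isinstance(value, dict) and isinstance(default.get(key), dict):
            nested = recursive_diff(value, default[key])
            if nested:  # keep the nested entry only if something inside it actually changed
                diff[key] = nested
        elif key not in default or default[key] != value:
            diff[key] = value
    return diff


# Example: only overridden values survive, including inside the nested "backbone_config".
current = {"d_model": 512, "backbone_config": {"model_type": "resnet", "depths": [3, 4, 6, 3]}}
default = {"d_model": 256, "backbone_config": {"model_type": "resnet", "depths": [3, 4, 23, 3]}}
print(recursive_diff(current, default))  # {'d_model': 512, 'backbone_config': {'depths': [3, 4, 6, 3]}}
```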
66 changes: 33 additions & 33 deletions src/transformers/modeling_utils.py
@@ -1187,13 +1187,13 @@ def _get_dtype(
dtype = getattr(torch, dtype)
config.dtype = dtype
for sub_config_key in config.sub_configs:
- sub_config = getattr(config, sub_config_key)
- sub_config.dtype = dtype
+ if (sub_config := getattr(config, sub_config_key)) is not None:
+     sub_config.dtype = dtype
Comment on lines -1190 to +1191

Member:
I think this check is only needed if super().__init__() is called before the sub-configs are set on the config, no? If so, let's quickly move that call to the bottom in all the changed configs instead of adding this check!

Member Author (@zucchini-nlp, Oct 8, 2025):
Not only that: some models do not initialize a sub-config at runtime, and either 1) infer it during modeling in load_backbone from a variety of config keys, or 2) never take the sub-model path, so they don't need a config for it.

These are all vision models with backbones, and maybe we can update them to follow the common standard for nested configs. I would prefer to check and do that standardization much later, after making and finishing the config type validation PR. So I'll merge this for now and take a note to come back and dig into the backbone-vision configs :)

elif isinstance(dtype, torch.dtype):
config.dtype = dtype
for sub_config_key in config.sub_configs:
- sub_config = getattr(config, sub_config_key)
- sub_config.dtype = dtype
+ if (sub_config := getattr(config, sub_config_key)) is not None:
+     sub_config.dtype = dtype
elif isinstance(dtype, dict):
for key, curr_dtype in dtype.items():
if hasattr(config, key):
@@ -1218,8 +1218,8 @@ def _get_dtype(
default_dtype = torch.get_default_dtype()
config.dtype = default_dtype
for key in config.sub_configs:
- value = getattr(config, key)
- value.dtype = default_dtype
+ if (sub_config := getattr(config, key)) is not None:
+     sub_config.dtype = default_dtype

return config, dtype, dtype_orig
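As the review discussion above notes, a declared sub-config entry may legitimately be None at runtime (for example, a backbone config that is only inferred later in load_backbone), so the hunks above only propagate the dtype when the entry exists. A minimal, self-contained sketch of that guard; the `ToyConfig` classes and their fields are invented for illustration and are not the transformers API:

```python
import torch


class ToySubConfig:
    def __init__(self) -> None:
        self.dtype = torch.float32


class ToyConfig:
    # Declared at the class level: the key always exists, but its value may be None at runtime.
    sub_configs = {"backbone_config": ToySubConfig, "text_config": ToySubConfig}

    def __init__(self) -> None:
        self.dtype = torch.float32
        self.backbone_config = None          # e.g. not needed for this particular checkpoint
        self.text_config = ToySubConfig()


def propagate_dtype(config: ToyConfig, dtype: torch.dtype) -> None:
    config.dtype = dtype
    for sub_config_key in config.sub_configs:
        # Skip sub-configs that were never instantiated instead of crashing on None.
        if (sub_config := getattr(config, sub_config_key)) is not None:
            sub_config.dtype = dtype


config = ToyConfig()
propagate_dtype(config, torch.bfloat16)
print(config.text_config.dtype)  # torch.bfloat16; backbone_config stayed None
```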

@@ -2673,34 +2673,34 @@ def set_attn_implementation(self, attn_implementation: Union[str, dict]):

# We need this as some old and badly designed models use subconfigs without declaring the corresponding modules as PreTrainedModel
for subconfig_key in self.config.sub_configs:
- subconfig = getattr(self.config, subconfig_key)
- sub_implementation = (
-     requested_implementation
-     if not isinstance(attn_implementation, dict)
-     else attn_implementation.get(subconfig_key, subconfig._attn_implementation)
- )
- # This means we did not perform any check above for this particular subconfig -> set it in the dark if it is registered
- if (
-     not hasattr(subconfig, "_attn_was_changed")
-     # If it's already the same, then no need to enter here and raise warnings
-     and sub_implementation != subconfig._attn_implementation
- ):
-     if sub_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
-         raise ValueError(
-             f'Specified `attn_implementation="{sub_implementation}"` is not supported for {subconfig_key}. '
-             'The only possible arguments are "eager" (manual attention implementation)'
-             f"or one of the following: {list(ALL_ATTENTION_FUNCTIONS.valid_keys())}"
-         )
-     subconfig._attn_implementation_internal = sub_implementation
-     logger.warning(
-         f"We set the attention implementation for the sub-config `{subconfig_key}` to `{sub_implementation}` "
-         "without finding the associated sub-model. For this reason we could not check if the model supports it. "
-         "You may encounter undefined behavior."
-     )
- # Unset the attribute in this case, to avoid issues in the future
- else:
-     if hasattr(subconfig, "_attn_was_changed"):
-         del subconfig._attn_was_changed
+ if (subconfig := getattr(self.config, subconfig_key)) is not None:
+     sub_implementation = (
+         requested_implementation
+         if not isinstance(attn_implementation, dict)
+         else attn_implementation.get(subconfig_key, subconfig._attn_implementation)
+     )
+     # This means we did not perform any check above for this particular subconfig -> set it in the dark if it is registered
+     if (
+         not hasattr(subconfig, "_attn_was_changed")
+         # If it's already the same, then no need to enter here and raise warnings
+         and sub_implementation != subconfig._attn_implementation
+     ):
+         if sub_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
+             raise ValueError(
+                 f'Specified `attn_implementation="{sub_implementation}"` is not supported for {subconfig_key}. '
+                 'The only possible arguments are "eager" (manual attention implementation)'
+                 f"or one of the following: {list(ALL_ATTENTION_FUNCTIONS.valid_keys())}"
+             )
+         subconfig._attn_implementation_internal = sub_implementation
+         logger.warning(
+             f"We set the attention implementation for the sub-config `{subconfig_key}` to `{sub_implementation}` "
+             "without finding the associated sub-model. For this reason we could not check if the model supports it. "
+             "You may encounter undefined behavior."
+         )
+     # Unset the attribute in this case, to avoid issues in the future
+     else:
+         if hasattr(subconfig, "_attn_was_changed"):
+             del subconfig._attn_was_changed

def enable_input_require_grads(self):
"""
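The hunk above is where a user-provided attention implementation is forwarded to sub-configs that have no matching sub-model. Below is a small, hedged illustration (not the transformers code path itself) of how a per-sub-config dict request can be resolved against declared sub-configs while skipping entries that are None; the class names and the `SUPPORTED` set are invented for the example:

```python
SUPPORTED = {"eager", "sdpa", "flash_attention_2"}  # stand-in for ALL_ATTENTION_FUNCTIONS.valid_keys()


class ToySubConfig:
    def __init__(self, attn: str = "eager") -> None:
        self._attn_implementation = attn


class ToyConfig:
    sub_configs = {"text_config": ToySubConfig, "vision_config": ToySubConfig}

    def __init__(self) -> None:
        self.text_config = ToySubConfig()
        self.vision_config = None  # this checkpoint has no vision tower


def set_attn(config: ToyConfig, attn_implementation) -> None:
    for key in config.sub_configs:
        if (subconfig := getattr(config, key)) is None:
            continue  # nothing to configure for absent sub-models
        requested = (
            attn_implementation
            if not isinstance(attn_implementation, dict)
            else attn_implementation.get(key, subconfig._attn_implementation)
        )
        if requested not in SUPPORTED:
            raise ValueError(f"{requested!r} is not a known attention implementation")
        subconfig._attn_implementation = requested


config = ToyConfig()
set_attn(config, {"text_config": "sdpa"})       # per-sub-config dict request
print(config.text_config._attn_implementation)  # sdpa
```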
@@ -23,7 +23,7 @@
from ...onnx import OnnxConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
- from ..auto import CONFIG_MAPPING
+ from ..auto import CONFIG_MAPPING, AutoConfig


logger = logging.get_logger(__name__)
@@ -135,6 +135,7 @@ class ConditionalDetrConfig(PreTrainedConfig):
```"""

model_type = "conditional_detr"
+ sub_configs = {"backbone_config": AutoConfig}
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"hidden_size": "d_model",
@@ -245,22 +246,6 @@ def __init__(
self.focal_alpha = focal_alpha
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

- @property
- def num_attention_heads(self) -> int:
-     return self.encoder_attention_heads
-
- @property
- def hidden_size(self) -> int:
-     return self.d_model
-
- @property
- def sub_configs(self):
-     return (
-         {"backbone_config": type(self.backbone_config)}
-         if getattr(self, "backbone_config", None) is not None
-         else {}
-     )


class ConditionalDetrOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
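This is the pattern repeated across the DETR-style configs in this PR: the per-instance `sub_configs` property is dropped in favor of a class-level declaration, and callers tolerate a None value for the declared key. A rough before/after sketch with invented config classes (not the actual ConditionalDetrConfig code):

```python
# Before: sub_configs was computed per instance and disappeared when no backbone config was set.
class OldStyleConfig:
    def __init__(self, backbone_config=None):
        self.backbone_config = backbone_config

    @property
    def sub_configs(self):
        return {"backbone_config": type(self.backbone_config)} if self.backbone_config is not None else {}


# After: the key is always declared on the class; the instance value may still be None,
# which is why callers now guard with `if (sub := getattr(config, key)) is not None`.
class NewStyleConfig:
    sub_configs = {"backbone_config": object}  # the PR declares AutoConfig as the value type

    def __init__(self, backbone_config=None):
        self.backbone_config = backbone_config


print(OldStyleConfig().sub_configs)  # {} -- the key vanishes entirely
print(NewStyleConfig.sub_configs)    # {'backbone_config': ...} -- always declared
```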
19 changes: 2 additions & 17 deletions src/transformers/models/d_fine/configuration_d_fine.py
@@ -21,7 +21,7 @@
from ...configuration_utils import PreTrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
- from ..auto import CONFIG_MAPPING
+ from ..auto import CONFIG_MAPPING, AutoConfig


logger = logging.get_logger(__name__)
@@ -194,6 +194,7 @@ class DFineConfig(PreTrainedConfig):
"""

model_type = "d_fine"
+ sub_configs = {"backbone_config": AutoConfig}
layer_types = ["basic", "bottleneck"]
attribute_map = {
"hidden_size": "d_model",
@@ -396,22 +397,6 @@ def __init__(
)
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

- @property
- def num_attention_heads(self) -> int:
-     return self.encoder_attention_heads
-
- @property
- def hidden_size(self) -> int:
-     return self.d_model
-
- @property
- def sub_configs(self):
-     return (
-         {"backbone_config": type(self.backbone_config)}
-         if getattr(self, "backbone_config", None) is not None
-         else {}
-     )

@classmethod
def from_backbone_configs(cls, backbone_config: PreTrainedConfig, **kwargs):
"""Instantiate a [`DFineConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
19 changes: 2 additions & 17 deletions src/transformers/models/d_fine/modular_d_fine.py
@@ -25,7 +25,7 @@
from ...image_transforms import corners_to_center_format
from ...utils import is_torchdynamo_compiling, logging
from ...utils.backbone_utils import verify_backbone_config_arguments
- from ..auto import CONFIG_MAPPING
+ from ..auto import CONFIG_MAPPING, AutoConfig
from ..rt_detr.modeling_rt_detr import (
RTDetrConvNormLayer,
RTDetrDecoder,
@@ -213,6 +213,7 @@ class DFineConfig(PreTrainedConfig):
"""

model_type = "d_fine"
+ sub_configs = {"backbone_config": AutoConfig}
layer_types = ["basic", "bottleneck"]
attribute_map = {
"hidden_size": "d_model",
@@ -415,22 +416,6 @@ def __init__(
)
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

- @property
- def num_attention_heads(self) -> int:
-     return self.encoder_attention_heads
-
- @property
- def hidden_size(self) -> int:
-     return self.d_model
-
- @property
- def sub_configs(self):
-     return (
-         {"backbone_config": type(self.backbone_config)}
-         if getattr(self, "backbone_config", None) is not None
-         else {}
-     )

@classmethod
def from_backbone_configs(cls, backbone_config: PreTrainedConfig, **kwargs):
"""Instantiate a [`DFineConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
11 changes: 2 additions & 9 deletions src/transformers/models/dab_detr/configuration_dab_detr.py
@@ -17,7 +17,7 @@
from ...configuration_utils import PreTrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
- from ..auto import CONFIG_MAPPING
+ from ..auto import CONFIG_MAPPING, AutoConfig


logger = logging.get_logger(__name__)
@@ -136,6 +136,7 @@ class DabDetrConfig(PreTrainedConfig):
```"""

model_type = "dab-detr"
+ sub_configs = {"backbone_config": AutoConfig}
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"num_attention_heads": "encoder_attention_heads",
@@ -256,13 +257,5 @@ def __init__(
self.initializer_bias_prior_prob = initializer_bias_prior_prob
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

- @property
- def sub_configs(self):
-     return (
-         {"backbone_config": type(self.backbone_config)}
-         if getattr(self, "backbone_config", None) is not None
-         else {}
-     )


__all__ = ["DabDetrConfig"]
@@ -17,7 +17,7 @@
from ...configuration_utils import PreTrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
- from ..auto import CONFIG_MAPPING
+ from ..auto import CONFIG_MAPPING, AutoConfig


logger = logging.get_logger(__name__)
@@ -144,6 +144,7 @@ class DeformableDetrConfig(PreTrainedConfig):
```"""

model_type = "deformable_detr"
+ sub_configs = {"backbone_config": AutoConfig}
attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "encoder_attention_heads",
@@ -270,21 +271,5 @@ def __init__(
self.disable_custom_kernels = disable_custom_kernels
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

- @property
- def num_attention_heads(self) -> int:
-     return self.encoder_attention_heads
-
- @property
- def hidden_size(self) -> int:
-     return self.d_model
-
- @property
- def sub_configs(self):
-     return (
-         {"backbone_config": type(self.backbone_config)}
-         if getattr(self, "backbone_config", None) is not None
-         else {}
-     )


__all__ = ["DeformableDetrConfig"]
@@ -14,12 +14,10 @@
# limitations under the License.
"""DepthAnything model configuration"""

- import copy
-
from ...configuration_utils import PreTrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
- from ..auto.configuration_auto import CONFIG_MAPPING
+ from ..auto.configuration_auto import CONFIG_MAPPING, AutoConfig


logger = logging.get_logger(__name__)
@@ -89,6 +87,7 @@ class DepthAnythingConfig(PreTrainedConfig):
```"""

model_type = "depth_anything"
+ sub_configs = {"backbone_config": AutoConfig}

def __init__(
self,
@@ -151,26 +150,5 @@ def __init__(
self.depth_estimation_type = depth_estimation_type
self.max_depth = max_depth if max_depth else 1

- @property
- def sub_configs(self):
-     return (
-         {"backbone_config": type(self.backbone_config)}
-         if getattr(self, "backbone_config", None) is not None
-         else {}
-     )
-
- def to_dict(self):
-     """
-     Serializes this instance to a Python dictionary. Override the default [`~PreTrainedConfig.to_dict`]. Returns:
-         `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-     """
-     output = copy.deepcopy(self.__dict__)
-
-     if output["backbone_config"] is not None:
-         output["backbone_config"] = self.backbone_config.to_dict()
-
-     output["model_type"] = self.__class__.model_type
-     return output


__all__ = ["DepthAnythingConfig"]
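The removed `to_dict` override existed only to serialize `backbone_config` by hand; with `sub_configs` declared on the class, the generic serialization path can recurse into nested configs itself. The sketch below illustrates that idea with invented toy classes, not the real `PreTrainedConfig` machinery:

```python
import copy


class ToyBackboneConfig:
    def __init__(self, depths=(2, 2, 6, 2)):
        self.depths = list(depths)

    def to_dict(self):
        return copy.deepcopy(self.__dict__)


class ToyConfig:
    sub_configs = {"backbone_config": ToyBackboneConfig}

    def __init__(self, backbone_config=None, patch_size=14):
        self.backbone_config = backbone_config
        self.patch_size = patch_size

    def to_dict(self):
        # Generic: serialize every declared sub-config that is actually set, no per-model override needed.
        output = copy.deepcopy({k: v for k, v in self.__dict__.items() if k not in self.sub_configs})
        for key in self.sub_configs:
            value = getattr(self, key, None)
            output[key] = value.to_dict() if value is not None else None
        return output


print(ToyConfig(ToyBackboneConfig()).to_dict())
# {'patch_size': 14, 'backbone_config': {'depths': [2, 2, 6, 2]}}
```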
19 changes: 2 additions & 17 deletions src/transformers/models/detr/configuration_detr.py
@@ -23,7 +23,7 @@
from ...onnx import OnnxConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
- from ..auto import CONFIG_MAPPING
+ from ..auto import CONFIG_MAPPING, AutoConfig


logger = logging.get_logger(__name__)
@@ -133,6 +133,7 @@ class DetrConfig(PreTrainedConfig):
```"""

model_type = "detr"
+ sub_configs = {"backbone_config": AutoConfig}
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"hidden_size": "d_model",
@@ -244,22 +245,6 @@ def __init__(
self.eos_coefficient = eos_coefficient
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

- @property
- def num_attention_heads(self) -> int:
-     return self.encoder_attention_heads
-
- @property
- def hidden_size(self) -> int:
-     return self.d_model
-
- @property
- def sub_configs(self):
-     return (
-         {"backbone_config": type(self.backbone_config)}
-         if getattr(self, "backbone_config", None) is not None
-         else {}
-     )

@classmethod
def from_backbone_config(cls, backbone_config: PreTrainedConfig, **kwargs):
"""Instantiate a [`DetrConfig`] (or a derived class) from a pre-trained backbone model configuration.