45 changes: 45 additions & 0 deletions src/transformers/configuration_utils.py
@@ -50,6 +50,40 @@
SpecificPretrainedConfigType = TypeVar("SpecificPretrainedConfigType", bound="PretrainedConfig")


# Map containing deprecated/deleted model types, and the first version that does NOT support them.
# - set version > current version: the model type is deprecated, and users will see a warning;
# - set version <= current version: the model type is deleted, and users will see an exception pointing to the last
# version that supported it.
# NOTE: this variable is set here (and not in `models.auto`) to avoid circular imports. We want to use it with
# `PretrainedConfig` to make sure the deprecation warning is seen even if the user doesn't use auto classes.
PREVIOUSLY_SUPPORTED_MODELS_TYPES = {
"bort": "5.0.0",
"deta": "5.0.0",
"efficientformer": "5.0.0",
"ernie_m": "5.0.0",
"gptsan-japanese": "5.0.0",
"graphormer": "5.0.0",
"jukebox": "5.0.0",
"mctct": "5.0.0",
"mega": "5.0.0",
"mmbt": "5.0.0",
"nat": "5.0.0",
"nezha": "5.0.0",
"open-llama": "5.0.0",
"qdqbert": "5.0.0",
"realm": "5.0.0",
"retribert": "5.0.0",
"speech_to_text_2": "5.0.0",
"tapex": "5.0.0",
"trajectory_transformer": "5.0.0",
"transfo-xl": "5.0.0",
"tvlt": "5.0.0",
"van": "5.0.0",
"vit_hybrid": "5.0.0",
"xlm-prophetnet": "5.0.0",
}
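For illustration, here is a minimal sketch of how this map is meant to be read. The helper below is hypothetical (not part of the PR); it applies the same version comparison that `AutoConfig.from_pretrained` performs further down.

from packaging import version

def support_status(model_type: str, current_version: str) -> str:
    # Hypothetical helper: classifies a model type using the map above.
    last = PREVIOUSLY_SUPPORTED_MODELS_TYPES.get(model_type)
    if last is None:
        return "supported"
    if version.parse(current_version) >= version.parse(last):
        return "deleted"  # AutoConfig raises a ValueError
    return "deprecated"  # the config's __init__ emits a UserWarning

support_status("bort", "4.57.0")  # -> "deprecated"
support_status("bort", "5.0.0")   # -> "deleted"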


class PretrainedConfig(PushToHubMixin):
# no-format
r"""
@@ -344,6 +378,17 @@ def __init__(
logger.error(f"Can't set {key} with value {value} for {self}")
raise err

# Handle deprecations: if the model type is deprecated, we raise a warning
# (if this line is reached and the model type is in PREVIOUSLY_SUPPORTED_MODELS_TYPES, it means the model
# config class is still operational -- it's deprecated, not deleted)
if self.model_type in PREVIOUSLY_SUPPORTED_MODELS_TYPES:
last_version = PREVIOUSLY_SUPPORTED_MODELS_TYPES[self.model_type]
warnings.warn(
f"\n🚨🚨 The model type `{self.model_type}` is deprecated and will be removed from `transformers` in "
f"v{last_version}. If you want to continue using this model, make sure to pin `transformers` to "
f"a version older than v{last_version}. 🚨🚨\n"
)

def _create_id_label_maps(self, num_labels: int):
self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
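As a usage sketch: assuming the deprecated config class is still importable (as in the tests below, which use `gptsan-japanese`), the new warning can be observed directly.

import warnings

from transformers import GPTSanJapaneseConfig  # deprecated, but not yet deleted

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    GPTSanJapaneseConfig()  # __init__ reaches the deprecation branch above
assert any("gptsan-japanese" in str(w.message) for w in caught)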
58 changes: 26 additions & 32 deletions src/transformers/models/auto/configuration_auto.py
@@ -22,7 +22,10 @@
from collections.abc import Callable, Iterator, KeysView, ValuesView
from typing import Any, TypeVar, Union

from ...configuration_utils import PretrainedConfig
from packaging import version

from ... import __version__ as transformers_version
from ...configuration_utils import PREVIOUSLY_SUPPORTED_MODELS_TYPES, PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import CONFIG_NAME, logging

Expand Down Expand Up @@ -923,34 +926,6 @@
]
)

# This is tied to the processing `-` -> `_` in `model_type_to_module_name`. For example, instead of putting
# `transfo-xl` (as in `CONFIG_MAPPING_NAMES`), we should use `transfo_xl`.
DEPRECATED_MODELS = [
"bort",
"deta",
"efficientformer",
"ernie_m",
"gptsan_japanese",
"graphormer",
"jukebox",
"mctct",
"mega",
"mmbt",
"nat",
"nezha",
"open_llama",
"qdqbert",
"realm",
"retribert",
"speech_to_text_2",
"tapex",
"trajectory_transformer",
"transfo_xl",
"tvlt",
"van",
"vit_hybrid",
"xlm_prophetnet",
]

SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
[
@@ -1001,6 +976,13 @@
]
)

# List of modules containing deprecated models, built from `PREVIOUSLY_SUPPORTED_MODELS_TYPES` by applying the
# mapping from model type on the Hub to module name in `transformers`.
DEPRECATED_MODELS = [
SPECIAL_MODEL_TYPE_TO_MODULE_NAME.get(key, key).replace("-", "_")
for key in PREVIOUSLY_SUPPORTED_MODELS_TYPES.keys()
]
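For example, a small sketch of the transformation (the special-case entry here is hypothetical; the real table is `SPECIAL_MODEL_TYPE_TO_MODULE_NAME` above):

special = {"open-llama": "open_llama"}  # hypothetical subset of the special-case table
previous = {"gptsan-japanese": "5.0.0", "transfo-xl": "5.0.0", "open-llama": "5.0.0"}
print([special.get(key, key).replace("-", "_") for key in previous])
# -> ['gptsan_japanese', 'transfo_xl', 'open_llama']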


def model_type_to_module_name(key) -> str:
"""Converts a config key to the corresponding module."""
@@ -1012,7 +994,7 @@ def model_type_to_module_name(key) -> str:
key = f"deprecated.{key}"
return key

key = key.replace("-", "_")
key = key.replace("-", "_")  # folders in transformers always use `_` instead of `-`, regardless of the model type
if key in DEPRECATED_MODELS:
key = f"deprecated.{key}"
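A quick behavioral sketch of the resulting module resolution (expected return values assume the deprecated list built above; `llama` is just a non-deprecated example):

model_type_to_module_name("transfo-xl")       # -> "deprecated.transfo_xl"
model_type_to_module_name("gptsan-japanese")  # -> "deprecated.gptsan_japanese"
model_type_to_module_name("llama")            # -> "llama"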

@@ -1349,10 +1331,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[s
config_class.register_for_auto_class()
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "model_type" in config_dict:
# Apply heuristic: if model_type is mistral but layer_types is present, treat as ministral
# Handle past deletions
if config_dict["model_type"] in PREVIOUSLY_SUPPORTED_MODELS_TYPES:
current_version = version.parse(transformers_version)
last_version = version.parse(PREVIOUSLY_SUPPORTED_MODELS_TYPES[config_dict["model_type"]])
if current_version >= last_version:
raise ValueError(
f"The model type `{config_dict['model_type']}` was removed from `transformers` in "
f"v{last_version}. To use this model, make sure to install a corresponding version of "
"`transformers`"
)

# Mistral heuristic: if model_type is mistral but layer_types is present, treat as ministral
if config_dict["model_type"] == "mistral" and "layer_types" in config_dict:
logger.info(
"Detected mistral model with layer_types, treating as ministral for alternating attention compatibility. "
"Detected mistral model with layer_types, treating as ministral for alternating attention "
"compatibility. "
)
config_dict["model_type"] = "ministral"

32 changes: 31 additions & 1 deletion tests/utils/test_configuration_utils.py
@@ -23,7 +23,7 @@

import httpx

from transformers import AutoConfig, BertConfig, Florence2Config, GPT2Config
from transformers import AutoConfig, BertConfig, Florence2Config, GPT2Config, GPTSanJapaneseConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test, require_torch

@@ -356,3 +356,33 @@ def test_bc_torch_dtype(self):

config = PretrainedConfig.from_pretrained(tmpdirname, torch_dtype="float32")
self.assertEqual(config.dtype, "float32")

def test_deleted_model_type(self):
# TODO: when we have actually removed a model class:
# - let's use that model as a test case, instead of mocking fake versions.
# - let's also test trying to load the deleted class, without AutoConfig
with mock.patch(
"transformers.models.auto.configuration_auto.PREVIOUSLY_SUPPORTED_MODELS_TYPES",
{"gptsan-japanese": "4.0.0"},
):
with self.assertRaises(ValueError):
AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTSanJapaneseForConditionalGeneration")

def test_deprecated_model_type(self):
# This is intentionally a live test, testing against actually deprecated (but not yet deleted) model classes.
# We want to be highly confident that users are seeing the deprecation warnings.
with self.assertWarns(UserWarning) as cm:
AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTSanJapaneseForConditionalGeneration")
self.assertIn("gptsan-japanese", str(cm.warning))
self.assertIn("v5.0.0", str(cm.warning))
self.assertIn("transformers", str(cm.warning))
self.assertIn("🚨", str(cm.warning))

with self.assertWarns(UserWarning) as cm:
GPTSanJapaneseConfig.from_pretrained(
"hf-internal-testing/tiny-random-GPTSanJapaneseForConditionalGeneration"
)
self.assertIn("gptsan-japanese", str(cm.warning))
self.assertIn("v5.0.0", str(cm.warning))
self.assertIn("transformers", str(cm.warning))
self.assertIn("🚨", str(cm.warning))
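Both tests can be selected directly with standard pytest filtering, e.g. `pytest tests/utils/test_configuration_utils.py -k "deleted_model_type or deprecated_model_type"`.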