diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 312cbe2ca8b8..f566a4667d00 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -33,7 +33,6 @@
 from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ..pytorch_utils import (
     is_torch_greater_or_equal,
-    is_torch_greater_or_equal_than_2_3,
     is_torch_greater_or_equal_than_2_6,
 )
 
@@ -765,8 +764,6 @@ def convert_and_export_with_cache(
     Returns:
         Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
     """
-    if not is_torch_greater_or_equal_than_2_3:
-        raise ImportError("torch >= 2.3 is required.")
 
     import torch.export._trace
 
@@ -1035,8 +1032,6 @@ def export_with_dynamic_cache(
     Returns:
         Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
     """
-    if not is_torch_greater_or_equal_than_2_3:
-        raise ImportError("torch >= 2.3 is required.")
 
     # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
     ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 7ab90f1433fc..5b34cc221a0a 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -372,23 +372,20 @@ def get_state_dict_dtype(state_dict):
     "U8": torch.uint8,
     "I8": torch.int8,
     "I16": torch.int16,
+    "U16": torch.uint16,
     "F16": torch.float16,
     "BF16": torch.bfloat16,
     "I32": torch.int32,
+    "U32": torch.uint32,
     "F32": torch.float32,
     "F64": torch.float64,
     "I64": torch.int64,
+    "U64": torch.uint64,
     "F8_E4M3": torch.float8_e4m3fn,
     "F8_E5M2": torch.float8_e5m2,
 }
 
-if is_torch_greater_or_equal("2.3.0"):
-    str_to_torch_dtype["U16"] = torch.uint16
-    str_to_torch_dtype["U32"] = torch.uint32
-    str_to_torch_dtype["U64"] = torch.uint64
-
-
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
     is_quantized: bool = False,
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index 8086fb1e2e98..c8745768238b 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -36,9 +36,9 @@
 is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
 is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
 is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
-is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
 
 # For backwards compatibility (e.g. some remote codes on Hub using those variables).
+is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
 is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)
 is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
 is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True)
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 8e85209abc9a..93ba0aadef69 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -91,9 +91,9 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 @lru_cache
 def is_torch_available() -> bool:
     is_available, torch_version = _is_package_available("torch", return_version=True)
-    if is_available and version.parse(torch_version) < version.parse("2.2.0"):
-        logger.warning_once(f"Disabling PyTorch because PyTorch >= 2.2 is required but found {torch_version}")
-    return is_available and version.parse(torch_version) >= version.parse("2.2.0")
+    if is_available and version.parse(torch_version) < version.parse("2.3.0"):
+        logger.warning_once(f"Disabling PyTorch because PyTorch >= 2.3 is required but found {torch_version}")
+    return is_available and version.parse(torch_version) >= version.parse("2.3.0")
 
 
 @lru_cache
diff --git a/tests/test_executorch.py b/tests/test_executorch.py
index 0e33253c08f1..e41a91435dc3 100644
--- a/tests/test_executorch.py
+++ b/tests/test_executorch.py
@@ -23,16 +23,12 @@
     TorchExportableModuleWithHybridCache,
     TorchExportableModuleWithStaticCache,
 )
-from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
 from transformers.testing_utils import require_torch
 
 
 @require_torch
 class ExecutorchTest(unittest.TestCase):
     def setUp(self):
-        if not is_torch_greater_or_equal_than_2_3:
-            self.skipTest("torch >= 2.3 is required")
-
         set_seed(0)
         self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
         self.model.eval()
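
Note (not part of the patch): a minimal sketch of why the modeling_utils.py change is safe. torch.uint16, torch.uint32 and torch.uint64 only exist from torch 2.3 onward, so once torch >= 2.3 is the hard minimum (as import_utils.py now enforces), the unsigned entries can live directly in the dtype map instead of being registered behind a version check. The standalone dict below is illustrative only, not the library's own object.

# Minimal sketch, assuming torch >= 2.3 is installed so the unsigned
# 16/32/64-bit dtypes are available; maps safetensors dtype strings to torch dtypes.
import torch

str_to_torch_dtype = {
    "U8": torch.uint8,
    "U16": torch.uint16,  # available since torch 2.3
    "U32": torch.uint32,  # available since torch 2.3
    "U64": torch.uint64,  # available since torch 2.3
    "F32": torch.float32,
}

# Resolve a dtype string the way state-dict loading code would.
print(str_to_torch_dtype["U16"])  # torch.uint16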