src/transformers/integrations/executorch.py (5 changes: 0 additions & 5 deletions)

@@ -33,7 +33,6 @@
 from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ..pytorch_utils import (
     is_torch_greater_or_equal,
-    is_torch_greater_or_equal_than_2_3,
     is_torch_greater_or_equal_than_2_6,
 )

@@ -765,8 +764,6 @@ def convert_and_export_with_cache(
     Returns:
         Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
     """
-    if not is_torch_greater_or_equal_than_2_3:
-        raise ImportError("torch >= 2.3 is required.")

     import torch.export._trace

@@ -1035,8 +1032,6 @@ def export_with_dynamic_cache(
     Returns:
         Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
     """
-    if not is_torch_greater_or_equal_than_2_3:
-        raise ImportError("torch >= 2.3 is required.")

     # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
     ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
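Both helpers touched above return a `torch.export.ExportedProgram` produced via `torch.export`; the removed guards only enforced the torch >= 2.3 floor that the library now enforces globally. For context, here is a minimal sketch of the underlying `torch.export.export` call on a toy module, which is the mechanism these wrappers build on (the module, shapes, and names below are made up for illustration and assume torch >= 2.3):

```python
import torch
from torch.export import export


class TinyBlock(torch.nn.Module):
    """A toy module standing in for a real transformers model."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(self.proj(x))


# Export with a concrete example input; the result is a torch.export.ExportedProgram,
# the same type documented in the docstrings above.
example_args = (torch.randn(1, 8),)
exported_program = export(TinyBlock().eval(), example_args)
print(type(exported_program))
```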
src/transformers/modeling_utils.py (9 changes: 3 additions & 6 deletions)

@@ -372,23 +372,20 @@ def get_state_dict_dtype(state_dict):
     "U8": torch.uint8,
     "I8": torch.int8,
     "I16": torch.int16,
+    "U16": torch.uint16,
     "F16": torch.float16,
     "BF16": torch.bfloat16,
     "I32": torch.int32,
+    "U32": torch.uint32,
     "F32": torch.float32,
     "F64": torch.float64,
     "I64": torch.int64,
+    "U64": torch.uint64,
     "F8_E4M3": torch.float8_e4m3fn,
     "F8_E5M2": torch.float8_e5m2,
 }


-if is_torch_greater_or_equal("2.3.0"):
-    str_to_torch_dtype["U16"] = torch.uint16
-    str_to_torch_dtype["U32"] = torch.uint32
-    str_to_torch_dtype["U64"] = torch.uint64
-
-
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
     is_quantized: bool = False,
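This hunk folds the unsigned-integer entries directly into `str_to_torch_dtype`, since `torch.uint16`, `torch.uint32`, and `torch.uint64` always exist once torch >= 2.3 is the floor. A self-contained sketch of how such a mapping is consumed follows; it is trimmed to the entries visible in this hunk, so it is illustrative rather than a copy of the full table in `modeling_utils.py`:

```python
# Illustrative dtype table, limited to the entries shown in the diff.
# Assumes torch >= 2.3, where the unsigned 16/32/64-bit dtypes are available.
import torch

str_to_torch_dtype = {
    "U8": torch.uint8,
    "I8": torch.int8,
    "I16": torch.int16,
    "U16": torch.uint16,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I32": torch.int32,
    "U32": torch.uint32,
    "F32": torch.float32,
    "F64": torch.float64,
    "I64": torch.int64,
    "U64": torch.uint64,
    "F8_E4M3": torch.float8_e4m3fn,
    "F8_E5M2": torch.float8_e5m2,
}

# Example: resolve the dtype string a safetensors header reports for a tensor.
print(str_to_torch_dtype["U16"])  # torch.uint16
```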
src/transformers/pytorch_utils.py (2 changes: 1 addition & 1 deletion)

@@ -36,9 +36,9 @@
 is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
 is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
 is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
-is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)

 # For backwards compatibility (e.g. some remote codes on Hub using those variables).
+is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
 is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)
 is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
 is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True)
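For reference, the flags above are built with `is_torch_greater_or_equal(version, accept_dev=True)`. The sketch below is my own reimplementation of what an `accept_dev`-style check plausibly does, comparing only the release segment so nightly builds of the minimum version pass; it is meant purely as illustration, not as the transformers source:

```python
# Hedged reimplementation of an accept_dev-style torch version check.
from packaging import version

import torch


def torch_at_least(minimum: str, accept_dev: bool = False) -> bool:
    parsed = version.parse(torch.__version__)
    if accept_dev:
        # Compare only the release segment, so "2.3.0.dev20240105" counts as 2.3.
        parsed = version.parse(parsed.base_version)
    return parsed >= version.parse(minimum)


print(torch_at_least("2.3", accept_dev=True))
```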
src/transformers/utils/import_utils.py (6 changes: 3 additions & 3 deletions)

@@ -91,9 +91,9 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 @lru_cache
 def is_torch_available() -> bool:
     is_available, torch_version = _is_package_available("torch", return_version=True)
-    if is_available and version.parse(torch_version) < version.parse("2.2.0"):
-        logger.warning_once(f"Disabling PyTorch because PyTorch >= 2.2 is required but found {torch_version}")
-    return is_available and version.parse(torch_version) >= version.parse("2.2.0")
+    if is_available and version.parse(torch_version) < version.parse("2.3.0"):
+        logger.warning_once(f"Disabling PyTorch because PyTorch >= 2.3 is required but found {torch_version}")
+    return is_available and version.parse(torch_version) >= version.parse("2.3.0")


 @lru_cache
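With this change, `is_torch_available()` warns once and returns `False` when torch is older than 2.3, rather than raising. A short sketch of how calling code typically consumes that gate; it uses the public `transformers.is_torch_available` helper, and the printed messages are illustrative only:

```python
# Sketch of consuming the library-level torch gate after the 2.3 bump.
from transformers import is_torch_available

if is_torch_available():
    import torch

    print("torch", torch.__version__, "meets the minimum (>= 2.3)")
else:
    print("torch is missing or older than 2.3; transformers disables its torch backend")
```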
tests/test_executorch.py (4 changes: 0 additions & 4 deletions)

@@ -23,16 +23,12 @@
     TorchExportableModuleWithHybridCache,
     TorchExportableModuleWithStaticCache,
 )
-from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
 from transformers.testing_utils import require_torch


 @require_torch
 class ExecutorchTest(unittest.TestCase):
     def setUp(self):
-        if not is_torch_greater_or_equal_than_2_3:
-            self.skipTest("torch >= 2.3 is required")
-
         set_seed(0)
         self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
         self.model.eval()
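The test no longer needs its own version skip because the library-level gate already guarantees torch >= 2.3 whenever the torch backend is active. A minimal, self-contained sketch of the resulting setUp pattern; the class name and the extra assertion are mine, added only to make the snippet runnable on its own:

```python
import unittest

from transformers import AutoModelForCausalLM, set_seed
from transformers.testing_utils import require_torch


@require_torch
class TinyLlamaLoadTest(unittest.TestCase):
    def setUp(self):
        # No torch-2.3 skip needed: transformers itself enforces the minimum version.
        set_seed(0)
        self.model = AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )
        self.model.eval()

    def test_model_loads_in_eval_mode(self):
        self.assertFalse(self.model.training)


if __name__ == "__main__":
    unittest.main()
```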