
Commit ac6d487

Update torch minimum version to 2.3

Signed-off-by: Yuanyuan Chen <[email protected]>

1 parent: b3bd815

File tree

5 files changed: +6 -18 lines changed

src/transformers/integrations/executorch.py

Lines changed: 0 additions & 5 deletions

@@ -32,7 +32,6 @@
 from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ..pytorch_utils import (
     is_torch_greater_or_equal,
-    is_torch_greater_or_equal_than_2_3,
     is_torch_greater_or_equal_than_2_6,
 )

@@ -764,8 +763,6 @@ def convert_and_export_with_cache(
     Returns:
         Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
     """
-    if not is_torch_greater_or_equal_than_2_3:
-        raise ImportError("torch >= 2.3 is required.")

     import torch.export._trace

@@ -1034,8 +1031,6 @@ def export_with_dynamic_cache(
     Returns:
         Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
     """
-    if not is_torch_greater_or_equal_than_2_3:
-        raise ImportError("torch >= 2.3 is required.")

     # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
     ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
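With the explicit runtime guards removed, callers of these export helpers rely on the package-level torch >= 2.3 floor instead of an in-function check. A hedged usage sketch of the export helper touched above; the static-cache generation config shown here is an assumption for illustration, and the exact configuration expected can vary across transformers versions:

# Hedged sketch: convert_and_export_with_cache now assumes torch >= 2.3 is
# already guaranteed at import time instead of re-checking it here.
import torch

from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.integrations.executorch import convert_and_export_with_cache

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
model.eval()
# Assumption: this export path expects a static cache configuration.
model.generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    max_length=32,
    cache_config={"batch_size": 1, "max_cache_len": 32},
)

with torch.no_grad():
    exported_program = convert_and_export_with_cache(model)  # a torch.export.ExportedProgram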

src/transformers/modeling_utils.py

Lines changed: 3 additions & 6 deletions

@@ -450,23 +450,20 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     "U8": torch.uint8,
     "I8": torch.int8,
     "I16": torch.int16,
+    "U16": torch.uint16,
     "F16": torch.float16,
     "BF16": torch.bfloat16,
     "I32": torch.int32,
+    "U32": torch.uint32,
     "F32": torch.float32,
     "F64": torch.float64,
     "I64": torch.int64,
+    "U64": torch.uint64,
     "F8_E4M3": torch.float8_e4m3fn,
     "F8_E5M2": torch.float8_e5m2,
 }


-if is_torch_greater_or_equal("2.3.0"):
-    str_to_torch_dtype["U16"] = torch.uint16
-    str_to_torch_dtype["U32"] = torch.uint32
-    str_to_torch_dtype["U64"] = torch.uint64
-
-
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
     is_quantized: bool = False,
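torch.uint16, torch.uint32, and torch.uint64 were introduced in PyTorch 2.3, which is why the unsigned entries can move directly into the mapping instead of being added behind a runtime version check. A small illustrative sketch of the lookup:

# Sketch: with torch >= 2.3 the unsigned dtypes exist unconditionally, so a
# safetensors dtype string maps straight to a torch dtype.
import torch

str_to_torch_dtype_sketch = {
    "U8": torch.uint8,
    "U16": torch.uint16,  # new in torch 2.3
    "U32": torch.uint32,  # new in torch 2.3
    "U64": torch.uint64,  # new in torch 2.3
}

t = torch.zeros(4, dtype=str_to_torch_dtype_sketch["U16"])
print(t.dtype)  # torch.uint16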

src/transformers/pytorch_utils.py

Lines changed: 1 addition & 1 deletion

@@ -37,9 +37,9 @@
 is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
 is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
 is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
-is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)

 # For backwards compatibility (e.g. some remote codes on Hub using those variables).
+is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
 is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)
 is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
 is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True)
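For context, these flags all reduce to a parsed-version comparison. A minimal sketch of that comparison, assuming only the packaging library; the real is_torch_greater_or_equal helper in transformers also handles dev/nightly builds when accept_dev=True:

# Minimal sketch of a version-gate flag; illustrative, not the transformers helper.
from packaging import version

import torch


def torch_at_least(minimum: str) -> bool:
    # Strip any local build suffix (e.g. "+cu121") before comparing.
    installed = version.parse(torch.__version__.split("+")[0])
    return installed >= version.parse(minimum)


is_torch_greater_or_equal_than_2_3_sketch = torch_at_least("2.3")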

src/transformers/utils/import_utils.py

Lines changed: 2 additions & 2 deletions

@@ -236,9 +236,9 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[

 _torch_available, _torch_version = _is_package_available("torch", return_version=True)
 if _torch_available:
-    _torch_available = version.parse(_torch_version) >= version.parse("2.2.0")
+    _torch_available = version.parse(_torch_version) >= version.parse("2.3.0")
     if not _torch_available:
-        logger.warning(f"Disabling PyTorch because PyTorch >= 2.2 is required but found {_torch_version}")
+        logger.warning(f"Disabling PyTorch because PyTorch >= 2.3 is required but found {_torch_version}")


 _essentia_available = importlib.util.find_spec("essentia") is not None
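A hedged sketch of what a probe like _is_package_available does: detect whether the package can be imported, then read its installed distribution version. The exact transformers implementation may differ; this only illustrates the mechanism the version floor above hooks into:

# Illustrative probe: find the package spec, then read its distribution version.
import importlib.metadata
import importlib.util


def probe(pkg_name: str) -> tuple[bool, str]:
    found = importlib.util.find_spec(pkg_name) is not None
    pkg_version = "N/A"
    if found:
        try:
            pkg_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            found = False
    return found, pkg_version


available, torch_version = probe("torch")
print(available, torch_version)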

tests/test_executorch.py

Lines changed: 0 additions & 4 deletions

@@ -23,16 +23,12 @@
     TorchExportableModuleWithHybridCache,
     TorchExportableModuleWithStaticCache,
 )
-from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
 from transformers.testing_utils import require_torch


 @require_torch
 class ExecutorchTest(unittest.TestCase):
     def setUp(self):
-        if not is_torch_greater_or_equal_than_2_3:
-            self.skipTest("torch >= 2.3 is required")
-
         set_seed(0)
         self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
         self.model.eval()