@@ -22,6 +22,8 @@ |
22 | 22 | from pathlib import Path |
23 | 23 | from typing import Any, Dict, List, Optional, Union |
24 | 24 |
| 25 | +from packaging import version |
| 26 | + |
25 | 27 | from .debug_utils import DebugOption |
26 | 28 | from .trainer_utils import ( |
27 | 29 | EvaluationStrategy, |
@@ -478,6 +480,8 @@ class TrainingArguments: |
478 | 480 | are also available. See the [Ray documentation]( |
479 | 481 | https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for |
480 | 482 | more options. |
| 483 | + use_mps_device (`bool`, *optional*, defaults to `False`): |
| 484 | + Whether to use the Apple Silicon chip-based `mps` device. |
481 | 485 | """ |
482 | 486 |
483 | 487 | output_dir: str = field( |
@@ -630,6 +634,9 @@ class TrainingArguments: |
630 | 634 | }, |
631 | 635 | ) |
632 | 636 | no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"}) |
| 637 | + use_mps_device: bool = field( |
| 638 | + default=False, metadata={"help": "Whether to use the Apple Silicon chip-based `mps` device."} |
| 639 | + ) |
633 | 640 | seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) |
634 | 641 | data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."}) |
635 | 642 | jit_mode_eval: bool = field( |
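For context, the new flag is passed like any other `TrainingArguments` field. A minimal usage sketch (the `output_dir` value is illustrative):

```python
from transformers import TrainingArguments

# Request the Apple Silicon backend explicitly; on a supported Mac,
# `_setup_devices` below resolves this to torch.device("mps").
args = TrainingArguments(output_dir="out", use_mps_device=True)
print(args.device)
```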
@@ -1368,16 +1375,42 @@ def _setup_devices(self) -> "torch.device": |
1368 | 1375 | device = torch.device("cuda", self.local_rank) |
1369 | 1376 | self._n_gpu = 1 |
1370 | 1377 | elif self.local_rank == -1: |
1371 | | - # if n_gpu is > 1 we'll use nn.DataParallel. |
1372 | | - # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` |
1373 | | - # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will |
1374 | | - # trigger an error that a device index is missing. Index 0 takes into account the |
1375 | | - # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` |
1376 | | - # will use the first GPU in that env, i.e. GPU#1 |
1377 | | - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
1378 | | - # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at |
1379 | | - # the default value. |
1380 | | - self._n_gpu = torch.cuda.device_count() |
| 1378 | + if self.use_mps_device: |
| 1379 | + if not torch.backends.mps.is_available(): |
| 1380 | + if not torch.backends.mps.is_built(): |
| 1381 | + raise AssertionError( |
| 1382 | + "MPS not available because the current PyTorch install was not " |
| 1383 | + "built with MPS enabled. Please install torch version >=1.12.0 on " |
| 1384 | + "your Apple silicon Mac running macOS 12.3 or later with a native " |
| 1385 | + "version (arm64) of Python" |
| 1386 | + ) |
| 1387 | + else: |
| 1388 | + raise AssertionError( |
| 1389 | + "MPS not available because the current macOS version is not 12.3+ " |
| 1390 | + "and/or you do not have an MPS-enabled device on this machine." |
| 1391 | + ) |
| 1392 | + else: |
| 1393 | + if not version.parse(version.parse(torch.__version__).base_version) > version.parse("1.12.0"): |
| 1394 | + warnings.warn( |
| 1395 | + "We strongly recommend installing PyTorch >= 1.13 (nightly version at the time of writing)" |
| 1396 | + " on your macOS machine. It has major fixes related to model correctness and performance" |
| 1397 | + " improvements for transformer-based models. Please refer to" |
| 1398 | + " https://github.com/pytorch/pytorch/issues/82707 for more details." |
| 1399 | + ) |
| 1400 | + device = torch.device("mps") |
| 1401 | + self._n_gpu = 1 |
| 1402 | + |
| 1403 | + else: |
| 1404 | + # if n_gpu is > 1 we'll use nn.DataParallel. |
| 1405 | + # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` |
| 1406 | + # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will |
| 1407 | + # trigger an error that a device index is missing. Index 0 takes into account the |
| 1408 | + # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` |
| 1409 | + # will use the first GPU in that env, i.e. GPU#1 |
| 1410 | + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
| 1411 | + # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at |
| 1412 | + # the default value. |
| 1413 | + self._n_gpu = torch.cuda.device_count() |
1381 | 1414 | else: |
1382 | 1415 | # Here, we'll use torch.distributed. |
1383 | 1416 | # Initializes the distributed backend which will take care of synchronizing nodes/GPUs |
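Outside the `Trainer`, the same availability distinction can be probed directly. A minimal sketch assuming PyTorch >= 1.12, where `torch.backends.mps` first appeared:

```python
import torch

if torch.backends.mps.is_available():
    # macOS 12.3+ on Apple Silicon with an MPS-enabled PyTorch build.
    device = torch.device("mps")
elif torch.backends.mps.is_built():
    # PyTorch was compiled with MPS, but the OS/hardware does not expose it.
    device = torch.device("cpu")
else:
    # This PyTorch build has no MPS support at all.
    device = torch.device("cpu")
print(device)
```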