diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index c0a8c80f0b09..064ff69554df 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -264,27 +264,7 @@ def is_torch_bf16_gpu_available():
         return False
 
     import torch
-
-    # since currently no utility function is available we build our own.
-    # some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
-    # with additional check for torch version
-    # to succeed: (torch is required to be >= 1.10 anyway)
-    # 1. the hardware needs to support bf16 (GPU arch >= Ampere, or CPU)
-    # 2. if using gpu, CUDA >= 11
-    # 3. torch.autocast exists
-    # XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
-    # really only correct for the 0th gpu (or currently set default device if different from 0)
-    if torch.cuda.is_available() and torch.version.cuda is not None:
-        if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
-            return False
-        if int(torch.version.cuda.split(".")[0]) < 11:
-            return False
-        if not hasattr(torch.cuda.amp, "autocast"):
-            return False
-    else:
-        return False
-
-    return True
+    return torch.cuda.is_bf16_supported()
 
 
 def is_torch_bf16_cpu_available():
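
For illustration only (not part of the patch): a minimal sketch of how the simplified helper can be exercised on a machine with a working CUDA setup. The import path and the inline expectations are assumptions based on the hunk above, not verified output.

    import torch
    from transformers.utils.import_utils import is_torch_bf16_gpu_available

    # After this change the helper defers to PyTorch's own capability check, so on a
    # machine where torch is importable the two calls below should agree.
    print(torch.cuda.is_bf16_supported())  # True only on bf16-capable GPUs (e.g. Ampere / compute capability >= 8)
    print(is_torch_bf16_gpu_available())   # same value, since the helper now just forwards the call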