|  | 
| 47 | 47 |     deprecate, | 
| 48 | 48 |     is_accelerate_available, | 
| 49 | 49 |     is_bitsandbytes_available, | 
|  | 50 | +    is_bitsandbytes_version, | 
| 50 | 51 |     is_torch_version, | 
| 51 | 52 |     logging, | 
| 52 | 53 | ) | 
| @@ -976,27 +977,52 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | 
| 976 | 977 | 
 | 
| 977 | 978 |         return model | 
| 978 | 979 | 
 | 
| 979 |  | -    # Taken from `transformers`. | 
|  | 980 | +    # Adapted from `transformers`. | 
| 980 | 981 |     @wraps(torch.nn.Module.cuda) | 
| 981 | 982 |     def cuda(self, *args, **kwargs): | 
| 982 |  | -        # Checks if the model has been loaded in 8-bit | 
|  | 983 | +        # Checks if the model has been loaded in 4-bit or 8-bit with BNB | 
| 983 | 984 |         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: | 
| 984 |  | -            raise ValueError( | 
| 985 |  | -                "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the" | 
| 986 |  | -                " model has already been set to the correct devices and cast to the correct `dtype`." | 
| 987 |  | -            ) | 
| 988 |  | -        else: | 
| 989 |  | -            return super().cuda(*args, **kwargs) | 
|  | 985 | +            if getattr(self, "is_loaded_in_8bit", False): | 
|  | 986 | +                raise ValueError( | 
|  | 987 | +                    "Calling `cuda()` is not supported for `8-bit` quantized models. " | 
|  | 988 | +                    "Please use the model as it is, since the model has already been set to the correct devices." | 
|  | 989 | +                ) | 
|  | 990 | +            elif is_bitsandbytes_version("<", "0.43.2"): | 
|  | 991 | +                raise ValueError( | 
|  | 992 | +                    "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " | 
|  | 993 | +                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." | 
|  | 994 | +                ) | 
|  | 995 | +        return super().cuda(*args, **kwargs) | 
| 990 | 996 | 
 | 
| 991 |  | -    # Taken from `transformers`. | 
|  | 997 | +    # Adapted from `transformers`. | 
| 992 | 998 |     @wraps(torch.nn.Module.to) | 
| 993 | 999 |     def to(self, *args, **kwargs): | 
| 994 |  | -        # Checks if the model has been loaded in 8-bit | 
|  | 1000 | +        dtype_present_in_args = "dtype" in kwargs | 
|  | 1001 | + | 
|  | 1002 | +        if not dtype_present_in_args: | 
|  | 1003 | +            for arg in args: | 
|  | 1004 | +                if isinstance(arg, torch.dtype): | 
|  | 1005 | +                    dtype_present_in_args = True | 
|  | 1006 | +                    break | 
|  | 1007 | + | 
|  | 1008 | +        # Checks if the model has been loaded in 4-bit or 8-bit with BNB | 
| 995 | 1009 |         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: | 
| 996 |  | -            raise ValueError( | 
| 997 |  | -                "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the" | 
| 998 |  | -                " model has already been set to the correct devices and cast to the correct `dtype`." | 
| 999 |  | -            ) | 
|  | 1010 | +            if dtype_present_in_args: | 
|  | 1011 | +                raise ValueError( | 
|  | 1012 | +                    "You cannot cast a bitsandbytes model to a new `dtype`. Make sure to load the model with `from_pretrained` using the" | 
|  | 1013 | +                    " desired `dtype` by passing the correct `torch_dtype` argument." | 
|  | 1014 | +                ) | 
|  | 1015 | + | 
|  | 1016 | +            if getattr(self, "is_loaded_in_8bit", False): | 
|  | 1017 | +                raise ValueError( | 
|  | 1018 | +                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" | 
|  | 1019 | +                    " model has already been set to the correct devices and cast to the correct `dtype`." | 
|  | 1020 | +                ) | 
|  | 1021 | +            elif is_bitsandbytes_version("<", "0.43.2"): | 
|  | 1022 | +                raise ValueError( | 
|  | 1023 | +                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " | 
|  | 1024 | +                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." | 
|  | 1025 | +                ) | 
| 1000 | 1026 |         return super().to(*args, **kwargs) | 
| 1001 | 1027 | 
 | 
| 1002 | 1028 |     # Taken from `transformers`. | 
|  | 
0 commit comments