import os
import re
from collections import OrderedDict
-from functools import partial
+from functools import partial, wraps
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

from torch import Tensor, nn

from .. import __version__
+from ..quantizers import DiffusersAutoQuantizer
+from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
    CONFIG_NAME,
    FLAX_WEIGHTS_NAME,
@@ -128,6 +130,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
    _supports_gradient_checkpointing = False
    _keys_to_ignore_on_load_unexpected = None
    _no_split_modules = None
+    _keep_in_fp32_modules = []

    def __init__(self):
        super().__init__()
@@ -407,6 +410,18 @@ def save_pretrained(
            create_pr=create_pr,
        )

+    def dequantize(self):
+        """
+        Potentially dequantize the model in case it has been quantized by a quantization method that supports
+        dequantization.
+        """
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+
+        if hf_quantizer is None:
+            raise ValueError("You need to first quantize your model in order to dequantize it")
+
+        return hf_quantizer.dequantize(self)
+
    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
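For orientation, here is a minimal sketch of how the new `dequantize()` guard behaves on a model that was never quantized; `UNet2DModel` is used only as a convenient `ModelMixin` subclass for illustration and is not part of this diff:

from diffusers import UNet2DModel

# Build a small unquantized model; no `hf_quantizer` attribute is ever attached to it.
model = UNet2DModel()

try:
    model.dequantize()
except ValueError as err:
    # dequantize() refuses to run because the model was never quantized.
    print(err)  # "You need to first quantize your model in order to dequantize it"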
@@ -625,8 +640,42 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
            **kwargs,
        )

-        # determine quantization config.
-        ##############################
+        # determine initial quantization config.
+        ###############################
+        pre_quantized = getattr(config, "quantization_config", None) is not None
+        if pre_quantized or quantization_config is not None:
+            if pre_quantized:
+                config.quantization_config = DiffusersAutoQuantizer.merge_quantization_configs(
+                    config.quantization_config, quantization_config
+                )
+            else:
+                config.quantization_config = quantization_config
+            hf_quantizer = DiffusersAutoQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized)
+        else:
+            hf_quantizer = None
+
+        if hf_quantizer is not None:
+            hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
+            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+            device_map = hf_quantizer.update_device_map(device_map)
+
+            # In order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`
+            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+
+            # Force-set to `True` for more memory efficiency
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.warning("`low_cpu_mem_usage` was None, now set to True since model is quantized.")
+
+        # Check if `_keep_in_fp32_modules` is not None
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+        )
+        if use_keep_in_fp32_modules:
+            keep_in_fp32_modules = cls._keep_in_fp32_modules
+        else:
+            keep_in_fp32_modules = []
+        ###############################

        # Determine if we're loading from a directory of sharded checkpoints.
        is_sharded = False
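As a rough usage sketch of the config-resolution path above: passing a `quantization_config` to `from_pretrained` hands control to `DiffusersAutoQuantizer`. The `BitsAndBytesConfig` import and the checkpoint name below are assumptions for illustration (they require the bitsandbytes backend and a CUDA GPU), not part of this diff:

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Assumed API: an 8-bit bitsandbytes config mirroring the transformers-style interface.
quant_config = BitsAndBytesConfig(load_in_8bit=True)

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)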
@@ -733,6 +782,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
            with accelerate.init_empty_weights():
                model = cls.from_config(config, **unused_kwargs)

+            if hf_quantizer is not None:
+                hf_quantizer.preprocess_model(
+                    model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+                )
+
+                # We store the original dtype for quantized models as we cannot easily retrieve it
+                # once the weights have been quantized
+                # Note that once you have loaded a quantized model, you can't change its dtype so this will
+                # remain a single source of truth
+                config._pre_quantization_dtype = torch_dtype
+
            # if device_map is None, load the state dict and move the params from meta device to the cpu
            if device_map is None and not is_sharded:
                param_device = "cpu"
@@ -754,6 +814,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                    device=param_device,
                    dtype=torch_dtype,
                    model_name_or_path=pretrained_model_name_or_path,
+                    hf_quantizer=hf_quantizer,
+                    keep_in_fp32_modules=keep_in_fp32_modules,
                )

                if cls._keys_to_ignore_on_load_unexpected is not None:
@@ -769,7 +831,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                # Load weights and dispatch according to the device_map
                # by default the device_map is None and the weights are loaded on the CPU
                force_hook = True
-                device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
+                device_map = _determine_device_map(
+                    model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
+                )
                if device_map is None and is_sharded:
                    # we load the parameters on the cpu
                    device_map = {"": "cpu"}
@@ -863,6 +927,47 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

        return model

+    @wraps(torch.nn.Module.cuda)
+    def cuda(self, *args, **kwargs):
+        # Checks if the model has been loaded in 4-bit or 8-bit with bitsandbytes
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            raise ValueError(
+                "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the"
+                " model has already been set to the correct devices and cast to the correct `dtype`."
+            )
+        else:
+            return super().cuda(*args, **kwargs)
+
+    @wraps(torch.nn.Module.to)
+    def to(self, *args, **kwargs):
+        # Checks if the model has been loaded in 4-bit or 8-bit with bitsandbytes
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            raise ValueError(
+                "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                " model has already been set to the correct devices and cast to the correct `dtype`."
+            )
+        return super().to(*args, **kwargs)
+
+    def half(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.half()` is not supported for quantized models. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().half(*args)
+
+    def float(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.float()` is not supported for quantized models. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().float(*args)
+
    @classmethod
    def _load_pretrained_model(
        cls,
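The `cuda()`, `to()`, `half()`, and `float()` overrides above reject explicit device or dtype casts once a model is quantized, since the quantizer has already placed and typed the weights. A short sketch of the expected behavior, under the same assumptions as the earlier loading example (bitsandbytes installed, CUDA GPU available, illustrative checkpoint name):

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# Each cast is rejected with a ValueError; the quantized weights keep their device and dtype.
for unsupported_call in (transformer.half, transformer.float, lambda: transformer.to("cuda")):
    try:
        unsupported_call()
    except ValueError as err:
        print(err)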