src/transformers/integrations/bitsandbytes.py (60 changes: 44 additions & 16 deletions)
@@ -21,7 +21,9 @@
logger = logging.get_logger(__name__)


def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
def set_module_quantized_tensor_to_device(
module, tensor_name, device, value=None, fp16_statistics=None, quantized_stats=None
):
"""
A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
@@ -39,6 +41,9 @@ class `Int8Params` from `bitsandbytes`.
The value of the tensor (useful when going from the meta device to any other device).
fp16_statistics (`torch.HalfTensor`, *optional*):
The list of fp16 statistics to set on the module, used for serialization.
quantized_stats (`dict[str, Any]`, *optional*):
Dict with the serialized quantization state (quant_state components) needed to rebuild a 4-bit parameter, used for serialization.

"""
# Recurse if needed
if "." in tensor_name:
@@ -58,8 +63,8 @@ class `Int8Params` from `bitsandbytes`.
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put it on {device}.")

is_4bit = False
is_8bit = False
prequantized_loading = fp16_statistics is not None or quantized_stats is not None

if is_buffer or not is_bitsandbytes_available():
is_8bit = False
is_4bit = False
@@ -74,32 +79,55 @@ class `Int8Params` from `bitsandbytes`.
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to("cpu")
if value.dtype == torch.int8:
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
"0.37.2"
)
if not is_8bit_serializable:
raise ValueError(
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
else:
new_value = torch.tensor(value, device="cpu")

# Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
# Since weights are saved in the correct "orientation", we skip transposing when loading.
if issubclass(module.source_cls, Conv1D) and fp16_statistics is None:
if issubclass(module.source_cls, Conv1D) and not prequantized_loading:
new_value = new_value.T

kwargs = old_value.__dict__

if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)):
raise ValueError(
f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status."
)

if is_8bit:
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
"0.37.2"
)
if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable:
raise ValueError(
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
if prequantized_loading:
setattr(new_value, "SCB", fp16_statistics.to(device))

elif is_4bit:
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
if prequantized_loading:
is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
"0.41"
)
# TODO update version number after BNB release with PR #753
if not is_4bit_serializable:
raise ValueError(
"Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
new_value = bnb.nn.Params4bit.from_prequantized(
data=new_value, quantized_stats=quantized_stats, requires_grad=False, device=device, **kwargs
)
else:
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)

module._parameters[tensor_name] = new_value
if fp16_statistics is not None:
setattr(module.weight, "SCB", fp16_statistics.to(device))

else:
raise ValueError("Quantized parameter passed with device.type == 'cuda'")

else:
if value is None:
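
Note (editor's addition, not part of the diff): a minimal sketch of the `prequantized_loading` consistency check added above, distilled into a standalone helper so it can be run without a GPU. The helper name `check_prequantized_dtype` and the sample tensors are illustrative only.

```python
import torch


def check_prequantized_dtype(value, fp16_statistics=None, quantized_stats=None):
    # Pre-quantized checkpoints ship int8/uint8 data together with statistics;
    # values that still need quantization must be floating point with no statistics.
    prequantized_loading = fp16_statistics is not None or quantized_stats is not None
    if prequantized_loading != (value.dtype in (torch.int8, torch.uint8)):
        raise ValueError(f"Value dtype `{value.dtype}` is not compatible with parameter quantization status.")


check_prequantized_dtype(torch.zeros(2, 2, dtype=torch.uint8), quantized_stats={"absmax": None})  # ok
check_prequantized_dtype(torch.zeros(2, 2, dtype=torch.float16))                                  # ok
```
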
src/transformers/modeling_utils.py (108 changes: 73 additions & 35 deletions)
@@ -635,6 +635,7 @@ def _load_state_dict_into_meta_model(
is_quantized=False,
is_safetensors=False,
keep_in_fp32_modules=None,
unexpected_keys=None,  # passing `unexpected_keys` so entries consumed by quantized loading can be removed from it
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -734,17 +735,40 @@ def _load_state_dict_into_meta_model(
elif not is_quantized:
# For backward compatibility with older versions of `accelerate`
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else:
if param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
fp16_statistics = state_dict[param_name.replace("weight", "SCB")]
else:
fp16_statistics = None
elif param.dtype in (torch.int8, torch.uint8) and is_quantized:
# handling newly quantized weights and loaded quantized weights
# edit the param.dtype restrictions and is_quantized condition when adding new quant methods
quantized_stats = {}

if (param_name + ".quant_state.bitsandbytes__fp4" in state_dict) or (
param_name + ".quant_state.bitsandbytes__nf4" in state_dict
):
# 4bit loading. Collecting components for restoring quantized weight
# This can be expanded to make a universal call for any quantized weight loading
for k, v in state_dict.items():
if param_name + "." in k:
quantized_stats[k] = v
unexpected_keys.remove(k)

if "SCB" not in param_name:
set_module_quantized_tensor_to_device(
model, param_name, param_device, value=param, fp16_statistics=fp16_statistics
model, param_name, param_device, value=param, quantized_stats=quantized_stats
)

elif param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
# 8bit loading. Could be combined with the above 4bit call.
# condition looks unreliable
fp16_statistics_key = param_name.replace("weight", "SCB")
unexpected_keys.remove(fp16_statistics_key)

if "SCB" not in param_name:
# looks like a redundant if -- .SCB should not be in the loaded_state_dict_keys or expected_keys
set_module_quantized_tensor_to_device(
model, param_name, param_device, value=param, fp16_statistics=state_dict[fp16_statistics_key]
)
else:
# loading not quantized params in quantized model
set_module_quantized_tensor_to_device(model, param_name, param_device, value=param)

return error_msgs, offload_index, state_dict_index
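
Note (editor's addition, illustrative): the 4-bit branch above gathers every auxiliary tensor serialized next to a quantized weight before handing it to `set_module_quantized_tensor_to_device`. Below is a standalone sketch of that gathering step; the helper name `collect_quantized_stats` and the example key names are placeholders that mirror the `param_name + ".quant_state.bitsandbytes__nf4"` convention used in the diff.

```python
import torch


def collect_quantized_stats(param_name, state_dict, unexpected_keys):
    # Collect the quant_state components stored alongside a 4-bit weight and
    # drop them from `unexpected_keys`, since they are consumed here.
    quantized_stats = {}
    for k, v in state_dict.items():
        if param_name + "." in k:
            quantized_stats[k] = v
            unexpected_keys.remove(k)
    return quantized_stats


state_dict = {
    "lm_head.weight": torch.zeros(8, 1, dtype=torch.uint8),
    "lm_head.weight.quant_state.bitsandbytes__nf4": torch.zeros(4),
    "lm_head.weight.absmax": torch.zeros(1),
}
unexpected = ["lm_head.weight.quant_state.bitsandbytes__nf4", "lm_head.weight.absmax"]
stats = collect_quantized_stats("lm_head.weight", state_dict, unexpected)
print(sorted(stats), unexpected)  # two quant-state keys collected, `unexpected` is now empty
```
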


@@ -1822,16 +1846,25 @@ def save_pretrained(
kwargs["token"] = token

# Checks if the model has been loaded in 8-bit
if getattr(self, "is_loaded_in_8bit", False) and getattr(self, "is_8bit_serializable", False):
warnings.warn(
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
" behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed.",
UserWarning,
if (
getattr(self, "is_loaded_in_8bit", False)
and getattr(self, "is_8bit_serializable", False)
and version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.37.2")
):
raise NotImplementedError(
"You are calling `save_pretrained` to a 8-bit converted model, but your `bitsandbytes` version doesn't support it. "
"If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
)

if getattr(self, "is_loaded_in_4bit", False):
# TODO: update bnb version in the statement. 0.41 is a temporary value to enable testing
if (
getattr(self, "is_loaded_in_4bit", False)
and getattr(self, "is_4bit_serializable", False)
and version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.41.0")
):
raise NotImplementedError(
"You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported"
"You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. "
"If you want to save 4-bit models, make sure to have `bitsandbytes>=0.42` installed."
)

if "save_config" in kwargs:
@@ -2362,8 +2395,11 @@ def from_pretrained(
use_safetensors = False

if is_bitsandbytes_available():
is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41")
# TODO update version number after BNB release with PR #753
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse("0.37.2")
else:
is_4bit_serializable = False
is_8bit_serializable = False

if trust_remote_code is True:
@@ -2590,10 +2626,8 @@ def from_pretrained(

quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict())

if (
is_8bit_serializable
and quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES
and load_in_8bit
if quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES and (
(is_8bit_serializable and load_in_8bit) or (is_4bit_serializable and load_in_4bit)
):
if quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES:
logger.warning(
@@ -2603,8 +2637,8 @@ def from_pretrained(
)
config.quantization_config = quantization_config
elif (
is_8bit_serializable
and not load_in_8bit
(is_8bit_serializable or is_4bit_serializable)
and not (load_in_8bit or load_in_4bit)
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
):
quantization_config = config.quantization_config
@@ -2619,8 +2653,9 @@ def from_pretrained(
)

load_in_8bit = quantization_config.load_in_8bit
load_in_4bit = quantization_config.load_in_4bit

if load_in_8bit:
if load_in_8bit or load_in_4bit:
if torch_dtype is None:
torch_dtype = torch.float16
if device_map is None:
@@ -2638,12 +2673,12 @@ def from_pretrained(

elif (
not is_8bit_serializable
and not load_in_8bit
and not (load_in_8bit or load_in_4bit)
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
):
logger.warning(
"Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct"
" `bitsandbytes` version to support int8 serialization. Please install the latest version of `bitsandbytes` with "
" `bitsandbytes` version to support 4 and 8 bit serialization. Please install the latest version of `bitsandbytes` with "
" `pip install --upgrade bitsandbytes`."
)
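
Note (editor's addition, illustrative): an end-to-end sketch of the workflow these `from_pretrained` changes enable. The model id and output directory are placeholders, and a CUDA device plus a `bitsandbytes` version that supports 4-bit serialization are assumed.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Quantize on load, then serialize the already-quantized weights.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", quantization_config=bnb_config)
model.save_pretrained("opt-350m-4bit")

# Reloading picks up the quantization_config from the saved config and goes
# through the pre-quantized loading path added in this PR.
reloaded = AutoModelForCausalLM.from_pretrained("opt-350m-4bit")
```
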

@@ -3017,6 +3052,7 @@ def from_pretrained(

model.config.quantization_config = quantization_config
model.is_8bit_serializable = is_8bit_serializable
model.is_4bit_serializable = is_4bit_serializable

if load_in_8bit and torch_dtype is None:
logger.warning(
@@ -3475,14 +3511,19 @@ def _find_mismatched_keys(
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(checkpoint_key.split(".")[1:])

if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
if model_key in model_state_dict:
if (
state_dict[checkpoint_key].shape[-1] == 1
and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
):
# Such mismatched weights are OK for 4-bit quantization.
# Need more reliable condition here, ideally based on type(module)
pass
elif state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys

if resolved_archive_file is not None:
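
Note (editor's addition, illustrative): the relaxed mismatch check above reflects how `bitsandbytes` packs two 4-bit values into each `uint8` byte, so the serialized tensor has a trailing dimension of 1 and half as many elements as the full-precision parameter. Shapes below are made up for illustration.

```python
import torch

model_param = torch.empty(4096, 4096, dtype=torch.float16, device="meta")  # parameter as declared by the model
checkpoint_tensor = torch.empty(4096 * 4096 // 2, 1, dtype=torch.uint8)  # packed 4-bit storage from the checkpoint

looks_like_packed_4bit = (
    checkpoint_tensor.shape[-1] == 1
    and checkpoint_tensor.numel() * 2 == model_param.numel()
)
print(looks_like_packed_4bit)  # True -> not reported as a mismatched key
```
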
@@ -3580,6 +3621,7 @@ def _find_mismatched_keys(
is_quantized=is_quantized,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
error_msgs += new_error_msgs
else:
@@ -3628,10 +3670,6 @@ def _find_mismatched_keys(
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

if is_quantized:
unexpected_keys = [elem for elem in unexpected_keys if "SCB" not in elem]
missing_keys = [elem for elem in missing_keys if "SCB" not in elem]

if len(unexpected_keys) > 0:
archs = [] if model.config.architectures is None else model.config.architectures
warner = logger.warning if model.__class__.__name__ in archs else logger.info