Closed
39 commits
159bda9
4bit refs and ifs in mu.PTM.from_pretrained()
poedator Sep 7, 2023
f5284e4
bnb.py calls from_prequantized() now
poedator Sep 7, 2023
a77cef4
modeling_utils 4bit support
poedator Sep 7, 2023
0f92002
bnb.py separate calls to Params4bit constructor
poedator Sep 7, 2023
96c510d
comment fix
poedator Sep 7, 2023
b7e7f5d
black applied
poedator Sep 7, 2023
0c7adba
fixed double .T for prequantized loading
poedator Sep 8, 2023
0cc7f08
added is_4bit_serializable
poedator Sep 12, 2023
22a996b
upd load 4b
poedator Sep 13, 2023
e83f463
fix for unexpected_keys when loading
poedator Sep 13, 2023
c1612c4
bnb interface streamlined
poedator Sep 13, 2023
7e116dd
bnb.py fixed
poedator Sep 14, 2023
47558a9
mismatch condition comment edited
poedator Sep 14, 2023
6a8dcd7
black
poedator Sep 14, 2023
30d4129
moved SCB setting in bnb.py
poedator Sep 14, 2023
7e94814
fstring fixes
poedator Sep 14, 2023
a1c0de6
added test for 4bit serialization
poedator Sep 14, 2023
0c0182d
quantized_stats descr in docstring intg/bnb
poedator Sep 22, 2023
a58d86f
simpler condition in 8bit check @intgr/bnb
poedator Sep 22, 2023
286ca6b
improved def _load_state_dict_into_meta_model(
poedator Sep 22, 2023
ca4db23
integr/bnb comments implemented & cleanup
poedator Sep 22, 2023
c95a4c0
NotImplementedError when bnb ver too low for 4/8 bit save
poedator Sep 22, 2023
5c64cc1
removed commented miss/unxp keys handler @3666
poedator Sep 22, 2023
d8e99d4
bnb version check reworked
poedator Sep 22, 2023
0b33bf7
[Mistral] Mistral-7B-v0.1 support
Bam4d Sep 27, 2023
4cca179
fixing names
Bam4d Sep 27, 2023
602428f
slightly longer test
Bam4d Sep 27, 2023
6e73cce
fixups
Bam4d Sep 27, 2023
0c755c8
not_doctested
Sep 27, 2023
59b0834
wrongly formatted references
Sep 27, 2023
b1c3d03
make fixuped
timlacroix Sep 27, 2023
d18fb6d
Merge pull request #1 from webpolis/save4
webpolis Sep 30, 2023
9a382d1
Merge pull request #2 from huggingface/main
webpolis Sep 30, 2023
ecc62d6
Merge branch 'add-mistral' of https://github.com/mistralai/transforme…
webpolis Sep 30, 2023
250f87e
Merge branch 'mistralai-add-mistral' into main
webpolis Sep 30, 2023
fd3bd61
Merge branch 'main' of https://github.com/huggingface/transformers in…
webpolis Oct 10, 2023
60e341f
Merge branch 'huggingface-main' into main
webpolis Oct 10, 2023
2ba93e9
Merge pull request #5 from huggingface/main
webpolis Oct 28, 2023
9276d20
Merge pull request #6 from huggingface/main
webpolis Nov 27, 2023
60 changes: 44 additions & 16 deletions src/transformers/integrations/bitsandbytes.py
@@ -21,7 +21,9 @@
logger = logging.get_logger(__name__)


def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
def set_module_quantized_tensor_to_device(
module, tensor_name, device, value=None, fp16_statistics=None, quantized_stats=None
):
"""
A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
@@ -39,6 +41,9 @@ class `Int8Params` from `bitsandbytes`.
The value of the tensor (useful when going from the meta device to any other device).
fp16_statistics (`torch.HalfTensor`, *optional*):
The list of fp16 statistics to set on the module, used for serialization.
quantized_stats (`dict[str, Any]`, *optional*):
Dict with the quantization statistics (`quant_state` components) needed to restore a 4-bit parameter.

"""
# Recurse if needed
if "." in tensor_name:
@@ -58,8 +63,8 @@ class `Int8Params` from `bitsandbytes`.
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")

is_4bit = False
is_8bit = False
prequantized_loading = fp16_statistics is not None or quantized_stats is not None

if is_buffer or not is_bitsandbytes_available():
is_8bit = False
is_4bit = False
@@ -74,32 +79,55 @@ class `Int8Params` from `bitsandbytes`.
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to("cpu")
if value.dtype == torch.int8:
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
"0.37.2"
)
if not is_8bit_serializable:
raise ValueError(
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
else:
new_value = torch.tensor(value, device="cpu")

# Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
# Since weights are saved in the correct "orientation", we skip transposing when loading.
if issubclass(module.source_cls, Conv1D) and fp16_statistics is None:
if issubclass(module.source_cls, Conv1D) and not prequantized_loading:
new_value = new_value.T

kwargs = old_value.__dict__

if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)):
raise ValueError(
f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status."
)

if is_8bit:
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
"0.37.2"
)
if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable:
raise ValueError(
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
if prequantized_loading:
setattr(new_value, "SCB", fp16_statistics.to(device))

elif is_4bit:
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
if prequantized_loading:
is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
"0.41"
)
# TODO update version number after BNB release with PR #753
if not is_4bit_serializable:
raise ValueError(
"Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
new_value = bnb.nn.Params4bit.from_prequantized(
data=new_value, quantized_stats=quantized_stats, requires_grad=False, device=device, **kwargs
)
else:
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)

module._parameters[tensor_name] = new_value
if fp16_statistics is not None:
setattr(module.weight, "SCB", fp16_statistics.to(device))

else:
raise ValueError("Quantized parameter passed with device.type == 'cuda'")

else:
if value is None:
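
To make the new branching concrete, here is a minimal standalone sketch of the two paths `set_module_quantized_tensor_to_device` now distinguishes for 4-bit parameters: a fresh fp16 weight is quantized on the fly by `Params4bit(...).to(device)`, while a prequantized uint8 weight plus its `quantized_stats` goes through `Params4bit.from_prequantized`. The tensor shape is a placeholder, a CUDA device is required, and the `quant_state.as_dict(packed=True)` helper is an assumption based on bitsandbytes releases that include the 4-bit serialization work (bnb PR #753).

```python
# Sketch only: assumes a bitsandbytes release with 4-bit serialization (bnb PR #753)
# and a CUDA device; the 4096x4096 shape is just an example.
import torch
import bitsandbytes as bnb

device = "cuda"

# Path 1: a not-yet-quantized fp16 weight is quantized when moved to the GPU.
fp16_weight = torch.randn(4096, 4096, dtype=torch.float16)
param = bnb.nn.Params4bit(fp16_weight, requires_grad=False).to(device)

# Path 2: a prequantized checkpoint stores the packed uint8 data plus its quantization statistics.
packed_data = param.data                                   # two 4-bit values per uint8 byte
quantized_stats = param.quant_state.as_dict(packed=True)   # what ends up in the saved state dict (assumed helper)
restored = bnb.nn.Params4bit.from_prequantized(
    data=packed_data, quantized_stats=quantized_stats, requires_grad=False, device=device
)
```
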
106 changes: 70 additions & 36 deletions src/transformers/modeling_utils.py
@@ -643,6 +643,7 @@ def _load_state_dict_into_meta_model(
is_quantized=False,
is_safetensors=False,
keep_in_fp32_modules=None,
unexpected_keys=None,  # passed so that keys consumed by quantization can be removed from `unexpected_keys`
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -744,17 +745,40 @@ def _load_state_dict_into_meta_model(
elif not is_quantized:
# For backward compatibility with older versions of `accelerate`
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else:
if param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
fp16_statistics = state_dict[param_name.replace("weight", "SCB")]
else:
fp16_statistics = None
elif param.dtype in (torch.int8, torch.uint8) and is_quantized:
# handles both freshly quantized weights and weights loaded from a prequantized checkpoint
# adjust the param.dtype restriction and the is_quantized condition when adding new quantization methods
quantized_stats = {}

if (param_name + ".quant_state.bitsandbytes__fp4" in state_dict) or (
param_name + ".quant_state.bitsandbytes__nf4" in state_dict
):
# 4-bit loading: collect the components needed to restore the quantized weight
# This could be generalized into a universal call for any quantized weight loading
for k, v in state_dict.items():
if param_name + "." in k:
quantized_stats[k] = v
unexpected_keys.remove(k)

if "SCB" not in param_name:
set_module_quantized_tensor_to_device(
model, param_name, param_device, value=param, fp16_statistics=fp16_statistics
model, param_name, param_device, value=param, quantized_stats=quantized_stats
)

elif param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
# 8-bit loading; could be combined with the 4-bit call above.
# NOTE: this detection condition looks unreliable
fp16_statistics_key = param_name.replace("weight", "SCB")
unexpected_keys.remove(fp16_statistics_key)

if "SCB" not in param_name:
# looks like a redundant if -- .SCB should not be in the loaded_state_dict_keys or expected_keys
set_module_quantized_tensor_to_device(
model, param_name, param_device, value=param, fp16_statistics=state_dict[fp16_statistics_key]
)
else:
# loading non-quantized params into a quantized model
set_module_quantized_tensor_to_device(model, param_name, param_device, value=param)

return error_msgs, offload_index, state_dict_index
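
As a side note on the key-collection loop above: for a prequantized 4-bit parameter, the checkpoint holds one packed uint8 tensor under the parameter name plus several sibling entries prefixed with that name, which together form `quantized_stats`. The sketch below illustrates that layout; the component names are an assumption based on what recent bitsandbytes versions emit, and only the `.quant_state.bitsandbytes__{fp4,nf4}` suffix is actually tested for in the loader.

```python
# Illustrative layout of a prequantized 4-bit weight in a checkpoint. Component key names are an
# assumption; the loader only probes for the ".quant_state.bitsandbytes__{fp4,nf4}" suffix and then
# sweeps up every key that starts with f"{param_name}.".
param_name = "model.layers.0.self_attn.q_proj.weight"  # hypothetical parameter

state_dict_excerpt = {
    param_name: "uint8 tensor of shape (out_features * in_features // 2, 1)",
    f"{param_name}.absmax": "per-block absolute maxima",
    f"{param_name}.quant_map": "the 4-bit codebook (nf4 or fp4)",
    f"{param_name}.nested_absmax": "statistics for double quantization",
    f"{param_name}.nested_quant_map": "codebook for the nested quantization",
    f"{param_name}.quant_state.bitsandbytes__nf4": "packed quant_state metadata",
}

# Mirrors the collection step in _load_state_dict_into_meta_model.
quantized_stats = {k: v for k, v in state_dict_excerpt.items() if k.startswith(f"{param_name}.")}
```
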


@@ -2042,18 +2066,23 @@ def save_pretrained(
# Checks if the model has been loaded in 8-bit
if (
getattr(self, "is_loaded_in_8bit", False)
and not getattr(self, "is_8bit_serializable", False)
and not _hf_peft_config_loaded
and getattr(self, "is_8bit_serializable", False)
and version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.37.2")
):
raise ValueError(
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
" behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
raise NotImplementedError(
"You are calling `save_pretrained` to a 8-bit converted model, but your `bitsandbytes` version doesn't support it. "
"If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
)

# If the model has adapters attached, you can save the adapters
if getattr(self, "is_loaded_in_4bit", False) and not _hf_peft_config_loaded:
# TODO: update bnb version in the statement. 0.41 is a temporary value to enable testing
if (
getattr(self, "is_loaded_in_4bit", False)
and getattr(self, "is_4bit_serializable", False)
and version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.41.0")
):
raise NotImplementedError(
"You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported"
"You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. "
"If you want to save 4-bit models, make sure to have `bitsandbytes>=0.42` installed."
)

if "save_config" in kwargs:
@@ -2619,8 +2648,11 @@ def from_pretrained(
use_safetensors = False

if is_bitsandbytes_available():
is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41")
# TODO update version number after BNB release with PR #753
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse("0.37.2")
else:
is_4bit_serializable = False
is_8bit_serializable = False

if trust_remote_code is True:
@@ -2889,10 +2921,8 @@ def from_pretrained(
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True

if (
is_8bit_serializable
and quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES
and load_in_8bit
if quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES and (
(is_8bit_serializable and load_in_8bit) or (is_4bit_serializable and load_in_4bit)
):
if quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES:
logger.warning(
@@ -2902,8 +2932,8 @@ def from_pretrained(
)
config.quantization_config = quantization_config
elif (
is_8bit_serializable
and not load_in_8bit
(is_8bit_serializable or is_4bit_serializable)
and not (load_in_8bit or load_in_4bit)
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
):
quantization_config = config.quantization_config
@@ -2918,8 +2948,9 @@ def from_pretrained(
)

load_in_8bit = quantization_config.load_in_8bit
load_in_4bit = quantization_config.load_in_4bit

if load_in_8bit:
if load_in_8bit or load_in_4bit:
if torch_dtype is None:
torch_dtype = torch.float16
if device_map is None:
@@ -2937,12 +2968,12 @@ def from_pretrained(

elif (
not is_8bit_serializable
and not load_in_8bit
and not (load_in_8bit or load_in_4bit)
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
):
logger.warning(
"Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct"
" `bitsandbytes` version to support int8 serialization. Please install the latest version of `bitsandbytes` with "
" `bitsandbytes` version to support 4 and 8 bit serialization. Please install the latest version of `bitsandbytes` with "
" `pip install --upgrade bitsandbytes`."
)

@@ -3338,6 +3369,7 @@ def from_pretrained(

config.quantization_config = quantization_config
model.is_8bit_serializable = is_8bit_serializable
model.is_4bit_serializable = is_4bit_serializable

if load_in_8bit and torch_dtype is None:
logger.warning(
@@ -3818,14 +3850,19 @@ def _find_mismatched_keys(
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(checkpoint_key.split(".")[1:])

if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
if model_key in model_state_dict:
if (
state_dict[checkpoint_key].shape[-1] == 1
and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
):
# Such shape mismatches are expected for 4-bit quantized weights (packed uint8 storage).
# TODO: a more reliable condition is needed here, ideally based on type(module)
pass
elif state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys

if resolved_archive_file is not None:
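
A quick calculation shows why the relaxed mismatch check above is safe to allow: bitsandbytes stores a 4-bit quantized weight as a column vector of uint8 with two values packed per byte, so the checkpoint tensor has last dimension 1 and exactly half as many elements as the full-precision parameter. A small sanity check (shapes are illustrative):

```python
import torch

# Full-precision shape the model expects, e.g. a 4096x4096 linear weight.
model_shape = (4096, 4096)
model_numel = model_shape[0] * model_shape[1]  # 16_777_216 elements

# The same weight prequantized to 4 bits: two values per uint8 byte, stored as a column vector.
packed = torch.empty(model_numel // 2, 1, dtype=torch.uint8)

# Exactly the conditions _find_mismatched_keys now tolerates instead of reporting a mismatch.
assert packed.shape[-1] == 1
assert packed.numel() * 2 == model_numel
```
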
@@ -3934,6 +3971,7 @@ def _find_mismatched_keys(
is_quantized=is_quantized,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
error_msgs += new_error_msgs
else:
@@ -3971,10 +4009,6 @@ def _find_mismatched_keys(
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

if is_quantized:
unexpected_keys = [elem for elem in unexpected_keys if "SCB" not in elem]
missing_keys = [elem for elem in missing_keys if "SCB" not in elem]

if len(unexpected_keys) > 0:
archs = [] if model.config.architectures is None else model.config.architectures
warner = logger.warning if model.__class__.__name__ in archs else logger.info
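
Taken together, the changes enable the round trip sketched below: quantize to 4-bit at load time, serialize with `save_pretrained` (previously rejected for 4-bit models), and reload the prequantized checkpoint directly. The model id and output path are placeholders, and the sketch assumes a CUDA machine, this branch of transformers, and a bitsandbytes release that includes 4-bit serialization (bnb PR #753).

```python
# Round-trip sketch of what this PR enables. Model id and output directory are placeholders;
# requires CUDA, this branch of transformers, and a bitsandbytes release with 4-bit serialization.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "facebook/opt-350m"  # placeholder checkpoint

# 1) Quantize on the fly while loading.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")

# 2) Serialize the already-quantized weights.
model.save_pretrained("opt-350m-4bit")

# 3) Reload the prequantized checkpoint; the quantization_config stored alongside the weights
#    routes loading through the new Params4bit.from_prequantized path.
reloaded = AutoModelForCausalLM.from_pretrained("opt-350m-4bit", device_map="auto")
```
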