
Commit 31725aa

harmonize changes with huggingface/transformers#33122
1 parent abc8607 commit 31725aa

File tree

7 files changed: +193 -55 lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 40 additions & 14 deletions
@@ -47,6 +47,7 @@
     deprecate,
     is_accelerate_available,
     is_bitsandbytes_available,
+    is_bitsandbytes_version,
     is_torch_version,
     logging,
 )
@@ -976,27 +977,52 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         return model
 
-    # Taken from `transformers`.
+    # Adapted from `transformers`.
     @wraps(torch.nn.Module.cuda)
     def cuda(self, *args, **kwargs):
-        # Checks if the model has been loaded in 8-bit
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
-            raise ValueError(
-                "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the"
-                " model has already been set to the correct devices and cast to the correct `dtype`."
-            )
-        else:
-            return super().cuda(*args, **kwargs)
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `8-bit` quantized models. "
+                    " Please use the model as it is, since the model has already been set to the correct devices."
+                )
+            elif is_bitsandbytes_version("<", "0.43.2"):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        return super().cuda(*args, **kwargs)
 
-    # Taken from `transformers`.
+    # Adapted from `transformers`.
     @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
-        # Checks if the model has been loaded in 8-bit
+        dtype_present_in_args = "dtype" in kwargs
+
+        if not dtype_present_in_args:
+            for arg in args:
+                if isinstance(arg, torch.dtype):
+                    dtype_present_in_args = True
+                    break
+
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
-            raise ValueError(
-                "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
-                " model has already been set to the correct devices and cast to the correct `dtype`."
-            )
+            if dtype_present_in_args:
+                raise ValueError(
+                    "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
+                    " desired `dtype` by passing the correct `torch_dtype` argument."
+                )
+
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                    " model has already been set to the correct devices and casted to the correct `dtype`."
+                )
+            elif is_bitsandbytes_version("<", "0.43.2"):
+                raise ValueError(
+                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
         return super().to(*args, **kwargs)
 
     # Taken from `transformers`.
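In practice, these guards allow device-only moves of 4-bit models when bitsandbytes >= 0.43.2 is installed, while dtype casts and any move of 8-bit models are still rejected. A minimal sketch of the resulting behavior, assuming diffusers' top-level `BitsAndBytesConfig` export and a hypothetical checkpoint id:

# Sketch only: "org/4bit-flux-transformer" is a hypothetical checkpoint id.
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

quant_config = BitsAndBytesConfig(load_in_4bit=True)
model = FluxTransformer2DModel.from_pretrained(
    "org/4bit-flux-transformer", quantization_config=quant_config, torch_dtype=torch.float16
)

model.to(torch.float16)  # raises ValueError: a bitsandbytes model cannot be cast to a new dtype
model.to("cpu")          # bitsandbytes >= 0.43.2: device-only move of a 4-bit model is allowed
                         # bitsandbytes < 0.43.2: raises ValueError pointing at the version requirement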

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 12 additions & 9 deletions
@@ -56,6 +56,7 @@
     is_accelerate_version,
     is_torch_npu_available,
     is_torch_version,
+    is_transformers_version,
     logging,
     numpy_to_pil,
 )
@@ -428,19 +429,23 @@ def module_is_offloaded(module):
         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
         for module in modules:
             _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
-            bit_map = {is_loaded_in_4bit_bnb: "4bit", is_loaded_in_8bit_bnb: "8bit"}
+            precision = None
+            precision = "4bit" if is_loaded_in_4bit_bnb else "8bit"
 
             if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
-                precision = bit_map[True]
                 logger.warning(
                     f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision. In most cases, it is recommended to not change the precision."
                 )
 
-            if (is_loaded_in_4bit_bnb or is_loaded_in_4bit_bnb) and device is not None:
-                precision = bit_map[True]
+            if is_loaded_in_8bit_bnb and device is not None:
                 logger.warning(
                     f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}. In most cases, it is recommended to not change the device."
                 )
+
+            # This can happen for `transformer` models. CPU placement was added in
+            # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
+            if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
+                module.to(device=device)
             else:
                 module.to(device, dtype)
 
@@ -449,6 +454,7 @@ def module_is_offloaded(module):
             and str(device) in ["cpu"]
             and not silence_dtype_warnings
            and not is_offloaded
+            and not is_loaded_in_4bit_bnb
         ):
             logger.warning(
                 "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
@@ -1023,16 +1029,13 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
             if model is not None and isinstance(model, torch.nn.Module):
                 _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(model)
 
-                bit_map = {is_loaded_in_4bit_bnb: "4bit", is_loaded_in_8bit_bnb: "8bit"}
-
                 if not isinstance(model, torch.nn.Module):
                     continue
 
                 # This is because the model would already be placed on a CUDA device.
-                if is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb:
-                    precision = bit_map[True]
+                if is_loaded_in_8bit_bnb:  # is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb:
                     logger.info(
-                        f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` {precision}."
+                        f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
                     )
                     continue
 
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py

Lines changed: 15 additions & 5 deletions
@@ -32,6 +32,7 @@
     is_accelerate_available,
     is_accelerate_version,
     is_bitsandbytes_available,
+    is_bitsandbytes_version,
     is_torch_available,
     logging,
 )
@@ -72,7 +73,7 @@ def validate_environment(self, *args, **kwargs):
             raise ImportError(
                 "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
             )
-        if not is_bitsandbytes_available():
+        if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"):
             raise ImportError(
                 "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
             )
@@ -319,9 +320,18 @@ def is_trainable(self) -> bool:
     def _dequantize(self, model):
         from .utils import dequantize_and_replace
 
+        is_model_on_cpu = model.device.type == "cpu"
+        if is_model_on_cpu:
+            logger.info(
+                "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
+            )
+            model.to(torch.cuda.current_device())
+
         model = dequantize_and_replace(
             model, self.modules_to_not_convert, quantization_config=self.quantization_config
         )
+        if is_model_on_cpu:
+            model.to("cpu")
         return model
 
 
@@ -348,17 +358,17 @@ def __init__(self, quantization_config, **kwargs):
         if self.quantization_config.llm_int8_skip_modules is not None:
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
 
-    # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4bit->8bit
+    # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4-bit->8-bit
     def validate_environment(self, *args, **kwargs):
         if not torch.cuda.is_available():
             raise RuntimeError("No GPU found. A GPU is needed for quantization.")
         if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"):
             raise ImportError(
-                "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
+                "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
             )
-        if not is_bitsandbytes_available():
+        if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"):
             raise ImportError(
-                "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
+                "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
            )
 
         if kwargs.get("from_flax", False):
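The `_dequantize` change is essentially a device round-trip: a model that was offloaded to CPU is moved to the current CUDA device for dequantization and then restored afterwards. A standalone sketch of that pattern (not the diffusers API itself; it assumes CUDA is available and a dequantization callable is supplied):

import torch

def dequantize_preserving_device(model, dequantize_fn):
    # Remember whether the model currently lives on CPU (e.g. after enable_model_cpu_offload()).
    was_on_cpu = model.device.type == "cpu"
    if was_on_cpu:
        # bitsandbytes dequantization expects the weights on a CUDA device.
        model.to(torch.cuda.current_device())
    model = dequantize_fn(model)
    if was_on_cpu:
        # Restore the original placement so callers see the device they started with.
        model.to("cpu")
    return model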

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@
     is_accelerate_available,
     is_accelerate_version,
     is_bitsandbytes_available,
+    is_bitsandbytes_version,
     is_bs4_available,
     is_flax_available,
     is_ftfy_available,

src/diffusers/utils/import_utils.py

Lines changed: 14 additions & 0 deletions
@@ -740,6 +740,20 @@ def is_peft_version(operation: str, version: str):
     return compare_versions(parse(_peft_version), operation, version)
 
 
+def is_bitsandbytes_version(operation: str, version: str):
+    """
+    Args:
+    Compares the current bitsandbytes version to a given reference with an operation.
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _bitsandbytes_version:
+        return False
+    return compare_versions(parse(_bitsandbytes_version), operation, version)
+
+
 def is_k_diffusion_version(operation: str, version: str):
     """
     Args:
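Since the helper is also re-exported from `diffusers.utils` (see the `__init__.py` change above), callers can use it as a simple guard; note it returns `False` when bitsandbytes is not installed. A small usage sketch:

from diffusers.utils import is_bitsandbytes_version

# False if bitsandbytes is missing, otherwise a plain version comparison.
if is_bitsandbytes_version(">=", "0.43.2"):
    # device-only moves of 4-bit bitsandbytes models are supported
    ...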

src/diffusers/utils/testing_utils.py

Lines changed: 26 additions & 0 deletions
@@ -1,5 +1,6 @@
 import functools
 import importlib
+import importlib.metadata
 import inspect
 import io
 import logging
@@ -404,6 +405,31 @@ def decorator(test_case):
     return decorator
 
 
+def require_bitsandbytes_version_greater(bnb_version):
+    def decorator(test_case):
+        correct_bnb_version = is_bitsandbytes_available() and version.parse(
+            version.parse(importlib.metadata.version("bitsandbytes")).base_version
+        ) > version.parse(bnb_version)
+        return unittest.skipUnless(
+            correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}."
+        )(test_case)
+
+    return decorator
+
+
+def require_transformers_version_greater(transformers_version):
+    def decorator(test_case):
+        correct_transformers_version = is_transformers_available() and version.parse(
+            version.parse(importlib.metadata.version("transformers")).base_version
+        ) > version.parse(transformers_version)
+        return unittest.skipUnless(
+            correct_transformers_version,
+            f"test requires transformers backend with the version greater than {transformers_version}",
+        )(test_case)
+
+    return decorator
+
+
 def deprecate_after_peft_backend(test_case):
     """
     Decorator marking a test that will be skipped after PEFT backend
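A sketch of how the two new decorators might be stacked on a test case (the test class and method names below are hypothetical):

# Sketch only: hypothetical test class exercising 4-bit device placement.
import unittest

from diffusers.utils.testing_utils import (
    require_bitsandbytes_version_greater,
    require_transformers_version_greater,
)


@require_bitsandbytes_version_greater("0.43.2")
@require_transformers_version_greater("4.44.0")
class Bnb4BitDevicePlacementTests(unittest.TestCase):
    def test_moving_4bit_model_to_cpu(self):
        ...  # exercise `model.to("cpu")` on a 4-bit bitsandbytes-quantized model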
