src/peft/utils/integrations.py (11 changes: 7 additions & 4 deletions)
@@ -93,14 +93,19 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     """
     import bitsandbytes as bnb
 
-    # BNB requires CUDA weights
+    if state.SCB is None:
+        state.SCB = weight.SCB
+
+    # BNB requires accelerator weights
     device = weight.device
     is_cpu = device.type == torch.device("cpu").type
     if is_cpu:
         if torch.cuda.is_available():
             weight = weight.to(torch.device("cuda"))
+            state.SCB = state.SCB.to(torch.device("cuda"))
         elif is_xpu_available():
             weight = weight.to(torch.device("xpu"))
+            state.SCB = state.SCB.to(torch.device("xpu"))
Contributor Author: weight is moved to the accelerator, but SCB was not, so the later int8_vectorwise_dequant call complains that weight is on XPU while SCB is on CPU, which makes the CPU CI fail. SCB is therefore moved to the accelerator as well (a minimal sketch of this pattern follows below).
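For illustration only, here is a small self-contained sketch of the device-alignment pattern described above. It is not the PEFT implementation; fake_int8_dequant is a stand-in for bnb.functional.int8_vectorwise_dequant.

import torch


def fake_int8_dequant(weight_int8: torch.Tensor, scb: torch.Tensor) -> torch.Tensor:
    # Stand-in for bnb.functional.int8_vectorwise_dequant: per-row int8 dequantization.
    return weight_int8.float() * scb.view(-1, 1) / 127.0


def dequantize_aligned(weight_int8: torch.Tensor, scb: torch.Tensor) -> torch.Tensor:
    original_device = weight_int8.device
    if original_device.type == "cpu" and torch.cuda.is_available():
        # Move BOTH tensors; moving only the weight leaves SCB on CPU and the
        # dequant kernel would raise a device-mismatch error.
        weight_int8 = weight_int8.to("cuda")
        scb = scb.to("cuda")
    dequantized = fake_int8_dequant(weight_int8, scb)
    # Move the result back to where the weight originally lived.
    return dequantized.to(original_device)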


     cls_name = weight.__class__.__name__
     if cls_name == "Params4bit":
@@ -109,9 +114,6 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
         dequantized = dequantized.to(device)
         return dequantized
 
-    if state.SCB is None:
-        state.SCB = weight.SCB
-
     if hasattr(bnb.functional, "int8_vectorwise_dequant"):
         # Use bitsandbytes API if available (requires v0.45.0+)
         dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
@@ -121,6 +123,7 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
 
     if is_cpu:
         dequantized = dequantized.to(device)
+        state.SCB = state.SCB.to(device)
Contributor Author: Both the dequantized weight and state.SCB should be moved back to the original device.
     return dequantized


tests/test_gpu_examples.py (55 changes: 31 additions & 24 deletions)
@@ -2817,28 +2817,29 @@ def test_olora_with_quantized_model(self, bits):
 @pytest.mark.skipif(
     not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a hardware accelerator"
 )
 @pytest.mark.single_gpu_tests
 @require_bitsandbytes
 class TestLoftQ:
     r"""
     Tests for LoftQ to ensure that it reduces the quantization error compared to normal LoRA quantization.
     """
 
-    # The error factor indicates by how much the quantization error should be decreased when using LoftQ compared to
-    # quantization without LoftQ. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
-    # conservative value to prevent flakiness, in practice most gains are > 1.5
-    device = infer_device()
-    error_factor = 1.005 if device in ("xpu", "cpu") else 1.03
+    def get_error_factor(self, device):
+        # The error factor indicates by how much the quantization error should be decreased when using LoftQ compared to
+        # quantization without LoftQ. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
+        # conservative value to prevent flakiness, in practice most gains are > 1.5
+        error_factor = 1.005 if device in ("xpu", "cpu") else 1.03
Member: Did you find it necessary to reduce the factor to 1.005 for XPU and CPU?

Contributor Author: Currently it is needed. We are continuing to improve bnb support; once that works, we can remove it.

Member: I see. I checked again with CUDA/CPU and the improvement is much larger than 1.03 (which, as indicated, is a very conservative value to avoid flakiness). So if XPU is below that, I think it's an indicator that something is missing.
+        return error_factor

     def get_input(self, model_id, device):
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer("All I want is", padding=True, return_tensors="pt")
-        inputs = inputs.to(self.device)
+        inputs = inputs.to(device)
         return inputs
 
     def get_base_model(self, model_id, device, **kwargs):
         cls = AutoModelForSeq2SeqLM if "t5" in str(model_id) else AutoModelForCausalLM
-        model = cls.from_pretrained(model_id, **kwargs).eval()
-        model = model.to(self.device)
+        model = cls.from_pretrained(model_id, device_map=device, **kwargs).eval()
         return model
 
     def get_logits(self, model, inputs):
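To make the error-factor semantics discussed above concrete, here is a hypothetical, self-contained illustration (not part of the test suite) of how the relative-improvement threshold behaves:

# Requiring mse_loftq < mse_quantized / error_factor means LoftQ must reduce the
# quantization error by roughly (error_factor - 1) relative to plain quantized LoRA.
def loftq_improves_enough(mse_quantized: float, mse_loftq: float, error_factor: float) -> bool:
    return mse_loftq < mse_quantized / error_factor


# With error_factor = 1.03, an MSE drop from 0.100 to 0.096 (about 4% better) passes,
# while a drop to only 0.098 (about 2% better) does not.
assert loftq_improves_enough(0.100, 0.096, 1.03)
assert not loftq_improves_enough(0.100, 0.098, 1.03)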
@@ -2882,7 +2883,7 @@ def get_errors(
             raise ValueError("bits must be 4 or 8")
 
         quantized_model = get_peft_model(
-            self.get_base_model(model_id, device=None, **kwargs),
+            self.get_base_model(model_id, device, **kwargs),
             lora_config,
         )
         torch.manual_seed(0)
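As a side note, a minimal usage sketch of loading a model directly onto the target device via device_map, which is the pattern the refactored get_base_model helper now relies on. The model id here is only an assumption for illustration and is not taken from the test suite.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigscience/bloomz-560m"  # assumed small causal LM for the example
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map=<device string> places the whole model on that device at load time,
# so no separate .to(device) call is needed afterwards.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device).eval()

inputs = tokenizer("All I want is", return_tensors="pt").to(device)
with torch.no_grad():
    logits = model(**inputs).logits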
@@ -2901,10 +2902,10 @@ def get_errors(
         )
         model = self.get_base_model(model_id, device)
         if device != "cpu":
-            model = model.to(torch_device)
+            model = model.to(device)
         loftq_model = get_peft_model(model, lora_config)
         if device != "cpu":
-            loftq_model = loftq_model.to(torch_device)
+            loftq_model = loftq_model.to(device)
 
         # save LoRA weights, they should be initialized such that they minimize the quantization error
         loftq_model.base_model.peft_config["default"].init_lora_weights = True
@@ -2917,7 +2918,7 @@ def get_errors(
         clear_device_cache(garbage_collection=True)
 
         # now load quantized model and apply LoftQ-initialized weights on top
-        base_model = self.get_base_model(tmp_path / "base_model", device=None, **kwargs, torch_dtype=torch.float32)
+        base_model = self.get_base_model(tmp_path / "base_model", device=device, **kwargs, torch_dtype=torch.float32)
         loftq_model = PeftModel.from_pretrained(base_model, tmp_path / "loftq_model", is_trainable=True)
 
         # TODO sanity check: model is quantized
@@ -2966,8 +2967,9 @@ def test_bloomz_loftq_4bit_iter_5(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit(self, device, tmp_path):
@@ -2981,8 +2983,9 @@ def test_bloomz_loftq_8bit(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
@@ -2998,8 +3001,9 @@ def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_loftq_4bit(self, device, tmp_path):
@@ -3013,8 +3017,9 @@ def test_t5_loftq_4bit(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_loftq_8bit(self, device, tmp_path):
@@ -3028,8 +3033,9 @@ def test_t5_loftq_8bit(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.xfail # failing for now, but having DoRA pass is only a nice-to-have, not a must, so we're good
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
@@ -3063,8 +3069,9 @@ def test_bloomz_loftq_8bit_dora(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mae_loftq < (mae_quantized / self.error_factor)
-        assert mse_loftq < (mse_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mae_loftq < (mae_quantized / error_factor)
+        assert mse_loftq < (mse_quantized / error_factor)
 
     def test_replace_lora_weights_with_loftq_using_callable(self):
         """