TST: fix to issue for 8-bit model
#2797
First changed file: the `dequantize_bnb_weight` helper.
@@ -93,14 +93,19 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
    """
    import bitsandbytes as bnb

    # BNB requires CUDA weights
    if state.SCB is None:
        state.SCB = weight.SCB

    # BNB requires accelerator weights
    device = weight.device
    is_cpu = device.type == torch.device("cpu").type
    if is_cpu:
        if torch.cuda.is_available():
            weight = weight.to(torch.device("cuda"))
            state.SCB = state.SCB.to(torch.device("cuda"))
        elif is_xpu_available():
            weight = weight.to(torch.device("xpu"))
            state.SCB = state.SCB.to(torch.device("xpu"))

    cls_name = weight.__class__.__name__
    if cls_name == "Params4bit":

@@ -109,9 +114,6 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
        dequantized = dequantized.to(device)
        return dequantized

    if state.SCB is None:
        state.SCB = weight.SCB

    if hasattr(bnb.functional, "int8_vectorwise_dequant"):
        # Use bitsandbytes API if available (requires v0.45.0+)
        dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)

@@ -121,6 +123,7 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):

    if is_cpu:
        dequantized = dequantized.to(device)
        state.SCB = state.SCB.to(device)
Contributor (Author)
Both should be moved back.
    return dequantized
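Assembled from the hunks above, the int8 path of `dequantize_bnb_weight` after this change roughly looks as follows. This is a sketch, not the verbatim source: the function is renamed here to make that clear, the 4-bit branch is left out, and the fallback for older bitsandbytes releases is elided. The point of the fix is that both `weight` and `state.SCB` are moved to the accelerator before dequantization, and both are moved back to the original CPU device afterwards.

```python
# Sketch of the int8 path of dequantize_bnb_weight after this change (not verbatim source).
import bitsandbytes as bnb
import torch
from accelerate.utils import is_xpu_available


def dequantize_int8_weight(weight: torch.nn.Parameter, state):
    if state.SCB is None:
        state.SCB = weight.SCB

    # BNB requires accelerator weights: move the weight *and* its scale (SCB)
    device = weight.device
    is_cpu = device.type == torch.device("cpu").type
    if is_cpu:
        if torch.cuda.is_available():
            weight = weight.to(torch.device("cuda"))
            state.SCB = state.SCB.to(torch.device("cuda"))
        elif is_xpu_available():
            weight = weight.to(torch.device("xpu"))
            state.SCB = state.SCB.to(torch.device("xpu"))

    # Use bitsandbytes API if available (requires v0.45.0+)
    dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)

    if is_cpu:
        # "both should be moved back": the result and SCB return to the original CPU device
        dequantized = dequantized.to(device)
        state.SCB = state.SCB.to(device)

    return dequantized
```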
Second changed file: the LoftQ tests (`TestLoftQ`).
@@ -2817,28 +2817,29 @@ def test_olora_with_quantized_model(self, bits):
@pytest.mark.skipif(
    not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a hardware accelerator"
)
@pytest.mark.single_gpu_tests
@require_bitsandbytes
class TestLoftQ:
    r"""
    Tests for LoftQ to ensure that it reduces the quantization error compared to normal LoRA quantization.
    """

    # The error factor indicates by how much the quantization error should be decreased when using LoftQ compared to
    # quantization without LoftQ. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
    # conservative value to prevent flakiness, in practice most gains are > 1.5
    device = infer_device()
    error_factor = 1.005 if device in ("xpu", "cpu") else 1.03
    def get_error_factor(self, device):
        # The error factor indicates by how much the quantization error should be decreased when using LoftQ compared to
        # quantization without LoftQ. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
        # conservative value to prevent flakiness, in practice most gains are > 1.5
        error_factor = 1.005 if device in ("xpu", "cpu") else 1.03
Member
Did you find it necessary to reduce the factor to 1.005 for XPU and CPU?

Contributor (Author)
Currently it's needed. We are continuing to enhance bnb support; once that works, we can remove it.

Member
I see. I checked again with CUDA/CPU and the improvement is much larger than 1.03 (which, as indicated, is a very conservative value to avoid flakiness). So if XPU is below that, I think it's an indicator that something is missing.
        return error_factor

    def get_input(self, model_id, device):
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        inputs = tokenizer("All I want is", padding=True, return_tensors="pt")
        inputs = inputs.to(self.device)
        inputs = inputs.to(device)
        return inputs

    def get_base_model(self, model_id, device, **kwargs):
        cls = AutoModelForSeq2SeqLM if "t5" in str(model_id) else AutoModelForCausalLM
        model = cls.from_pretrained(model_id, **kwargs).eval()
        model = model.to(self.device)
        model = cls.from_pretrained(model_id, device_map=device, **kwargs).eval()
        return model

    def get_logits(self, model, inputs):
@@ -2882,7 +2883,7 @@ def get_errors(
            raise ValueError("bits must be 4 or 8")

        quantized_model = get_peft_model(
            self.get_base_model(model_id, device=None, **kwargs),
            self.get_base_model(model_id, device, **kwargs),
            lora_config,
        )
        torch.manual_seed(0)

@@ -2901,10 +2902,10 @@ def get_errors(
        )
        model = self.get_base_model(model_id, device)
        if device != "cpu":
            model = model.to(torch_device)
            model = model.to(device)
        loftq_model = get_peft_model(model, lora_config)
        if device != "cpu":
            loftq_model = loftq_model.to(torch_device)
            loftq_model = loftq_model.to(device)

        # save LoRA weights, they should be initialized such that they minimize the quantization error
        loftq_model.base_model.peft_config["default"].init_lora_weights = True

@@ -2917,7 +2918,7 @@ def get_errors(
        clear_device_cache(garbage_collection=True)

        # now load quantized model and apply LoftQ-initialized weights on top
        base_model = self.get_base_model(tmp_path / "base_model", device=None, **kwargs, torch_dtype=torch.float32)
        base_model = self.get_base_model(tmp_path / "base_model", device=device, **kwargs, torch_dtype=torch.float32)
        loftq_model = PeftModel.from_pretrained(base_model, tmp_path / "loftq_model", is_trainable=True)

        # TODO sanity check: model is quantized
@@ -2966,8 +2967,9 @@ def test_bloomz_loftq_4bit_iter_5(self, device, tmp_path):
        assert mse_loftq > 0.0

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        assert mse_loftq < (mse_quantized / self.error_factor)
        assert mae_loftq < (mae_quantized / self.error_factor)
        error_factor = self.get_error_factor(device)
        assert mse_loftq < (mse_quantized / error_factor)
        assert mae_loftq < (mae_quantized / error_factor)

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_bloomz_loftq_8bit(self, device, tmp_path):
@@ -2981,8 +2983,9 @@ def test_bloomz_loftq_8bit(self, device, tmp_path):
        assert mse_loftq > 0.0

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        assert mse_loftq < (mse_quantized / self.error_factor)
        assert mae_loftq < (mae_quantized / self.error_factor)
        error_factor = self.get_error_factor(device)
        assert mse_loftq < (mse_quantized / error_factor)
        assert mae_loftq < (mae_quantized / error_factor)

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
@@ -2998,8 +3001,9 @@ def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
        assert mse_loftq > 0.0

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        assert mse_loftq < (mse_quantized / self.error_factor)
        assert mae_loftq < (mae_quantized / self.error_factor)
        error_factor = self.get_error_factor(device)
        assert mse_loftq < (mse_quantized / error_factor)
        assert mae_loftq < (mae_quantized / error_factor)

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_t5_loftq_4bit(self, device, tmp_path):
@@ -3013,8 +3017,9 @@ def test_t5_loftq_4bit(self, device, tmp_path):
        assert mse_loftq > 0.0

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        assert mse_loftq < (mse_quantized / self.error_factor)
        assert mae_loftq < (mae_quantized / self.error_factor)
        error_factor = self.get_error_factor(device)
        assert mse_loftq < (mse_quantized / error_factor)
        assert mae_loftq < (mae_quantized / error_factor)

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_t5_loftq_8bit(self, device, tmp_path):
@@ -3028,8 +3033,9 @@ def test_t5_loftq_8bit(self, device, tmp_path):
        assert mse_loftq > 0.0

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        assert mse_loftq < (mse_quantized / self.error_factor)
        assert mae_loftq < (mae_quantized / self.error_factor)
        error_factor = self.get_error_factor(device)
        assert mse_loftq < (mse_quantized / error_factor)
        assert mae_loftq < (mae_quantized / error_factor)

    @pytest.mark.xfail  # failing for now, but having DoRA pass is only a nice-to-have, not a must, so we're good
    @pytest.mark.parametrize("device", [torch_device, "cpu"])
@@ -3063,8 +3069,9 @@ def test_bloomz_loftq_8bit_dora(self, device, tmp_path):
        assert mse_loftq > 0.0

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        assert mae_loftq < (mae_quantized / self.error_factor)
        assert mse_loftq < (mse_quantized / self.error_factor)
        error_factor = self.get_error_factor(device)
        assert mae_loftq < (mae_quantized / error_factor)
        assert mse_loftq < (mse_quantized / error_factor)

    def test_replace_lora_weights_with_loftq_using_callable(self):
        """
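The thread above questioned reducing the factor to 1.005 for XPU and CPU. As an illustration only (not part of the PR; the helper name below is hypothetical), this is what the per-device error factor means in these assertions: with 1.03 the LoftQ error must be at least about 3% smaller than the plain quantized-LoRA error, while xpu/cpu currently only require about 0.5%.

```python
# Illustration only: how the per-device error factor gates the checks in the tests above.
def loftq_improves_enough(mse_quantized: float, mse_loftq: float, device: str) -> bool:
    # 1.03 -> LoftQ error must be >= ~3% smaller; xpu/cpu currently only require ~0.5%
    error_factor = 1.005 if device in ("xpu", "cpu") else 1.03
    return mse_loftq < (mse_quantized / error_factor)


assert loftq_improves_enough(1.0, 0.95, device="cuda")       # 5% better -> passes the 1.03 margin
assert not loftq_improves_enough(1.0, 0.98, device="cuda")   # 2% better -> fails the 1.03 margin
assert loftq_improves_enough(1.0, 0.98, device="xpu")        # 2% better -> passes the relaxed 1.005 margin
```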
`weight` is moved to the accelerator but `SCB` is not, which makes the later `int8_vectorwise_dequant` complain that `weight` is on xpu while `SCB` is on cpu, and that makes the CPU CI fail, so move `SCB` to the accelerator too.
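A minimal sketch of that failure mode (illustration only, assuming bitsandbytes >= 0.45.0 with `int8_vectorwise_dequant` and a CUDA device; the PR itself covers both CUDA and XPU): if the int8 weight has been moved to the accelerator while the scale tensor stays on CPU, the dequant call receives tensors on different devices. Keeping the scale on the same device as the weight, as the diff does, avoids this.

```python
import bitsandbytes as bnb
import torch

# int8 weight already moved to the accelerator, row-wise scales (SCB) still on CPU
int8_weight = torch.randint(-128, 127, (4, 8), dtype=torch.int8, device="cuda")
scb = torch.rand(4)

# bnb.functional.int8_vectorwise_dequant(int8_weight, scb)  # mixed devices -> RuntimeError

# Fix mirrored by the diff: move the scale together with the weight
scb = scb.to(int8_weight.device)
dequantized = bnb.functional.int8_vectorwise_dequant(int8_weight, scb)
print(dequantized.device, dequantized.dtype)
```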