From 20b936a4d873e7990d036e4971eb96187a4461da Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Mon, 29 Sep 2025 15:42:06 -0400
Subject: [PATCH] Linear8bitLt: support device movement after forward()

---
 bitsandbytes/nn/modules.py | 37 ++++++++++++++++++++++++++++++-------
 tests/test_linear8bitlt.py | 38 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 68 insertions(+), 7 deletions(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 1adf75e79..69d39277b 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -679,19 +679,27 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
 
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
-        if device is not None and device.type != "meta" and self.data.device.type == "cpu":
-            if device.type != "cpu" or self.data.dtype != torch.int8:
-                return self._quantize(device)
-        elif self.data.dtype == torch.int8 and device.type == "cpu":
-            self.CB = self.data
+        is_quantized = self.data.dtype == torch.int8
+        if not is_quantized and device is not None and device.type != "meta" and self.data.device.type == "cpu":
+            # We're moving from a CPU device to a non-meta device.
+            # In this circumstance, we want to quantize if we haven't already.
+            return self._quantize(device)
+
+        # Create a new parameter on the target device.
         new_param = Int8Params(
             super().to(device=device, dtype=dtype, non_blocking=non_blocking),
             requires_grad=self.requires_grad,
             has_fp16_weights=self.has_fp16_weights,
         )
-        new_param.CB = self.CB
-        new_param.SCB = self.SCB
+
+        # If we had already quantized, move the statistics appropriately.
+        if is_quantized and device is not None:
+            if self.CB is not None:
+                new_param.CB = new_param.data
+
+            if self.SCB is not None:
+                new_param.SCB = self.SCB.to(device)
 
         return new_param
 
@@ -1037,6 +1045,21 @@ def init_8bit_state(self):
         self.weight.CB = None
         self.weight.SCB = None
 
+    def to(self, *args, **kwargs):
+        # Call the parent to() method to handle standard parameter/buffer movement
+        result = super().to(*args, **kwargs)
+
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+
+        # Handle state tensors if needed.
+        if device is not None:
+            if result.state.CB is not None:
+                result.state.CB = result.state.CB.to(device)
+            if result.state.SCB is not None:
+                result.state.SCB = result.state.SCB.to(device)
+
+        return result
+
     def forward(self, x: torch.Tensor):
         self.state.is_training = self.training
         if self.weight.CB is not None:
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 51b4cf9cd..6da3c28f8 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -293,3 +293,41 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
         grad_compiled = x.grad.clone()
 
         torch.testing.assert_close(grad_compiled, grad_ref)
+
+
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
+@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
+def test_linear8bitlt_device_movement(device):
+    """Test moving a Linear8bitLt layer between CPU and an accelerator device."""
+
+    # Create a Linear8bitLt layer on CPU.
+    layer = bnb.nn.Linear8bitLt(32, 128, bias=False, has_fp16_weights=False)
+    torch.nn.init.xavier_uniform_(layer.weight)
+
+    # Create a sample input.
+    x = torch.randn(4, 32, dtype=torch.float16, device="cpu")
+
+    # Move to the device. This should quantize the weights.
+    layer = layer.to(device)
+    assert layer.weight.data.dtype == torch.int8
+
+    # Call the layer on the accelerator device.
+    out_accelerator = layer(x.to(device))
+
+    # Move back to CPU and call again.
+    layer = layer.to("cpu")
+    out_cpu = layer(x)
+
+    # Move back to the accelerator device and call again.
+    layer = layer.to(device)
+    out_accelerator_2 = layer(x.to(device))
+
+    # Move back to the CPU and call one last time.
+    layer = layer.to("cpu")
+    out_cpu_2 = layer(x)
+
+    # CPU outputs should match both times.
+    torch.testing.assert_close(out_cpu_2, out_cpu, rtol=1e-8, atol=1e-8)
+
+    # Accelerator outputs should match both times.
+    torch.testing.assert_close(out_accelerator_2, out_accelerator, rtol=1e-8, atol=1e-8)
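
---

Illustrative usage sketch (not part of the patch): it mirrors the new test to show the behavior this change enables, and assumes a build with this patch applied plus an available CUDA device; the layer sizes are arbitrary.

    import torch
    import bitsandbytes as bnb

    # Int8 linear layer created on CPU with frozen int8 weights.
    layer = bnb.nn.Linear8bitLt(32, 128, bias=False, has_fp16_weights=False)
    x = torch.randn(4, 32, dtype=torch.float16)

    # Moving to an accelerator quantizes the weights to int8.
    layer = layer.to("cuda")
    out_gpu = layer(x.to("cuda"))

    # With this patch, the layer (including state.CB / state.SCB populated by
    # forward()) can round-trip back to CPU and to the accelerator again.
    layer = layer.to("cpu")
    out_cpu = layer(x)
    layer = layer.to("cuda")
    out_gpu_again = layer(x.to("cuda"))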