5 changes: 5 additions & 0 deletions src/peft/import_utils.py
@@ -160,3 +160,8 @@ def is_xpu_available(check_device=False):
except RuntimeError:
return False
return hasattr(torch, "xpu") and torch.xpu.is_available()


@lru_cache
def is_diffusers_available():
return importlib.util.find_spec("diffusers") is not None
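Usage note: the new `is_diffusers_available` helper mirrors the existing availability checks and is cached with `lru_cache`, so the `importlib` lookup runs only once per process. A minimal sketch of how a test can consume it (the test function name is illustrative):

```python
import pytest

from peft.import_utils import is_diffusers_available


@pytest.mark.skipif(not is_diffusers_available(), reason="Test requires diffusers to be installed")
def test_something_that_needs_diffusers():
    # import lazily inside the test so that collection works without diffusers installed
    from diffusers import UNet2DConditionModel  # noqa: F401
```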
32 changes: 24 additions & 8 deletions src/peft/utils/hotswap.py
@@ -296,7 +296,7 @@ def prepare_model_for_compiled_hotswap(
# do inference with adapter 1
```
"""
is_compiled = hasattr(model, "_orig_mod")
is_compiled = hasattr(model, "_orig_mod") or getattr(model, "_compiled_call_impl", False)
if is_compiled:
raise ValueError("Call prepare_model_for_compiled_hotswap *before* compiling the model")
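Background for the broadened check above, as a hedged sketch (recent PyTorch behavior): `torch.compile(model)` returns a wrapper that keeps the original module under `_orig_mod`, while the in-place `model.compile()` leaves no `_orig_mod` and instead sets `_compiled_call_impl`, which is what the added `getattr` catches.

```python
import torch
import torch.nn as nn

# Out-of-place compilation: a wrapper module that keeps the original under `_orig_mod`.
wrapped = torch.compile(nn.Linear(4, 4))
print(hasattr(wrapped, "_orig_mod"))  # True

# In-place compilation: no wrapper, but `_compiled_call_impl` is set on the module itself.
inplace = nn.Linear(4, 4)
inplace.compile()
print(hasattr(inplace, "_orig_mod"))  # False
print(bool(getattr(inplace, "_compiled_call_impl", False)))  # True
```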

@@ -416,18 +416,34 @@ def hotswap_adapter_from_state_dict(
# swap actual weights
# no need to account for potential _orig_mod in key here, as torch handles that
old_val = attrgetter(key)(model)
new_val = new_val.to(old_val.data.device)

# We try to detect if the model is compiled, but detection does not always work, e.g. if hotswapping is called
# from within the model itself. In that case, swap_tensors raises a RuntimeError and we continue without
# swap_tensors.
if not is_compiled and not is_compiled_inplace:
torch.utils.swap_tensors(old_val, new_val)
continue
try:
torch.utils.swap_tensors(old_val, new_val)
continue
except RuntimeError:
is_compiled = True

# Compiled models don't work with swap_tensors because there are weakrefs to the tensors. It is unclear
# whether this workaround could cause trouble, but the tests indicate that it works.
if old_val.shape == new_val.shape:
# either
# - adapters had the same rank
# - adapters were padded with prepare_model_for_compiled_hotswap and 2nd adapter was larger
old_val.data = new_val.data
else:
if old_val.dim() != 2:
# TODO conv2d
raise NotImplementedError
# if 2nd adapter was smaller, ensure to fill up to adapter dimension and set the rest to zeros
if old_val.dim() not in (2, 4):
raise NotImplementedError(
f"Trying to hotswap an adapter whose weight has {old_val.dim()} dimensions, but only Conv2d and "
"Linear are supported"
)

# Linear or Conv2d: the check for dim 0 or 1 works for both of these layer types
if old_val.shape[0] > new_val.shape[0]:
old_val.data.fill_(0)
old_val.data[: new_val.shape[0]] = new_val.data
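To make the fallback above easier to follow: since the tensors of a compiled model cannot be swapped, the incoming adapter weight is copied into the existing tensor instead, zero-padding when the new adapter has a smaller rank. A condensed sketch of that logic; the dim-1 branch is assumed to mirror the visible dim-0 branch, as it falls outside the excerpt:

```python
import torch


def copy_adapter_weight_(old_val: torch.Tensor, new_val: torch.Tensor) -> None:
    """Sketch of the in-place fallback used when swap_tensors cannot be applied."""
    if old_val.shape == new_val.shape:
        # same rank, or both adapters were padded by prepare_model_for_compiled_hotswap
        old_val.data = new_val.data
        return

    if old_val.dim() not in (2, 4):  # only Linear (2D) and Conv2d (4D) weights are handled
        raise NotImplementedError

    if old_val.shape[0] > new_val.shape[0]:
        # smaller incoming rank along dim 0 (lora_A-style weight): zero-fill, then copy into the top slice
        old_val.data.fill_(0)
        old_val.data[: new_val.shape[0]] = new_val.data
    elif old_val.shape[1] > new_val.shape[1]:
        # assumed symmetric handling along dim 1 (lora_B-style weight)
        old_val.data.fill_(0)
        old_val.data[:, : new_val.shape[1]] = new_val.data
    else:
        raise ValueError("incoming adapter is larger than the prepared weight; pad with prepare_model_for_compiled_hotswap first")
```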
@@ -442,7 +458,7 @@ def hotswap_adapter_from_state_dict(
)


def _check_hotswap_configs_compatible(config0: PeftConfig, config1: PeftConfig) -> None:
def check_hotswap_configs_compatible(config0: PeftConfig, config1: PeftConfig) -> None:
"""
Check if two configs are compatible for hot-swapping.

@@ -548,7 +564,7 @@ def hotswap_adapter(model, model_name_or_path, adapter_name, torch_device=None,
]
config = config_cls.from_pretrained(model_name_or_path, **kwargs)
# config keys that could affect the model output besides what is determined by the state_dict
_check_hotswap_configs_compatible(model.active_peft_config, config)
check_hotswap_configs_compatible(model.active_peft_config, config)

state_dict = load_peft_weights(model_name_or_path, device=torch_device, **kwargs)

45 changes: 29 additions & 16 deletions tests/test_gpu_examples.py
@@ -28,8 +28,6 @@
from accelerate.test_utils.testing import run_command
from accelerate.utils import patch_environment
from datasets import Audio, Dataset, DatasetDict, load_dataset
from diffusers import UNet2DConditionModel
from diffusers.utils.testing_utils import floats_tensor
from packaging import version
from parameterized import parameterized
from torch.distributed import init_process_group
@@ -71,7 +69,7 @@
replace_lora_weights_loftq,
set_peft_model_state_dict,
)
from peft.import_utils import is_xpu_available
from peft.import_utils import is_diffusers_available, is_xpu_available
from peft.tuners import boft
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device
@@ -4237,7 +4235,7 @@ def check_hotswap(self, do_hotswap, ranks, alpha_scalings):
assert torch.allclose(output1, output_after1, atol=tol, rtol=tol)

# it is important to check hotswapping small to large ranks and large to small ranks
@pytest.mark.parametrize("ranks", [(7, 13), (13, 7)])
@pytest.mark.parametrize("ranks", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_does_not_trigger_recompilation(self, ranks):
with torch._dynamo.config.patch(error_on_recompile=True): # raise an error on recompilation
self.check_hotswap(do_hotswap=True, ranks=ranks, alpha_scalings=ranks)
@@ -4255,8 +4253,8 @@ def test_no_hotswapping_compiled_model_triggers_recompilation(self):

def get_small_unet(self):
# from diffusers UNet2DConditionModelTests
# TODO: This appears not to work yet in full pipeline context, see:
# https://github.com/huggingface/diffusers/pull/9453#issuecomment-2418508871
from diffusers import UNet2DConditionModel

torch.manual_seed(0)
init_dict = {
"block_out_channels": (4, 8),
@@ -4273,19 +4271,22 @@ def get_small_unet(self):
model = UNet2DConditionModel(**init_dict)
return model.to(self.torch_device)

def get_unet_lora_config(self, lora_rank, lora_alpha):
def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
# from diffusers test_models_unet_2d_condition.py
# note that this only targets linear layers by default
unet_lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
target_modules=target_modules,
init_lora_weights=False,
use_dora=False,
)
return unet_lora_config

def get_dummy_input(self):
# from UNet2DConditionModelTests
from diffusers.utils.testing_utils import floats_tensor

batch_size = 4
num_channels = 4
sizes = (16, 16)
@@ -4310,13 +4311,13 @@ def set_lora_device(self, model, adapter_names, device):
device
)

def check_hotswap_diffusion(self, do_hotswap, ranks, alpha_scalings):
def check_hotswap_diffusion(self, do_hotswap, ranks, alpha_scalings, target_modules):
dummy_input = self.get_dummy_input()
unet = self.get_small_unet()
rank0, rank1 = ranks
alpha0, alpha1 = alpha_scalings
lora_config0 = self.get_unet_lora_config(rank0, alpha0)
lora_config1 = self.get_unet_lora_config(rank1, alpha1)
lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules=target_modules)
lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules=target_modules)
unet.add_adapter(lora_config0, adapter_name="adapter0")
unet.add_adapter(lora_config1, adapter_name="adapter1")

@@ -4337,19 +4338,31 @@ def check_hotswap_diffusion(self, do_hotswap, ranks, alpha_scalings):
unet(**dummy_input)["sample"]

if do_hotswap:
unet.load_lora_adapter(file_name1, adapter_name="default_0", hotswap=True)
unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True)
else:
# offloading the old and loading the new adapter will result in recompilation
self.set_lora_device(unet, adapter_names=["default_0"], device="cpu")
self.set_lora_device(unet, adapter_names=["adapter0"], device="cpu")
unet.load_lora_adapter(file_name1, adapter_name="other_name", hotswap=False)

# we need to call forward to potentially trigger recompilation
unet(**dummy_input)["sample"]

@pytest.mark.skipif(not is_diffusers_available(), reason="Test requires diffusers to be installed")
@pytest.mark.xfail(
strict=True, reason="Requires hotswap to be implemented in diffusers", raises=torch._dynamo.exc.RecompileError
)
def test_hotswapping_compiled_diffusers_model_does_not_trigger_recompilation(self):
ranks = 7, 13
# it is important to check hotswapping small to large ranks and large to small ranks
@pytest.mark.parametrize("ranks", [(11, 11), (7, 13), (13, 7)])
@pytest.mark.parametrize(
"target_modules",
[
["to_q", "to_k", "to_v", "to_out.0"], # Linear layers
["conv", "conv1", "conv2"], # Conv2d layers
["to_q", "conv"], # mix of Linear and Conv2d
],
)
def test_hotswapping_compiled_diffusers_model_does_not_trigger_recompilation(self, ranks, target_modules):
with torch._dynamo.config.patch(error_on_recompile=True): # raise an error on recompilation
self.check_hotswap_diffusion(do_hotswap=True, ranks=ranks, alpha_scalings=ranks)
self.check_hotswap_diffusion(
do_hotswap=True, ranks=ranks, alpha_scalings=ranks, target_modules=target_modules
)
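The recompilation guard used in these tests is a reusable pattern: run everything under `torch._dynamo.config.patch(error_on_recompile=True)` so the first forward pass compiles and any later recompilation raises. A minimal sketch (the helper name is illustrative):

```python
import torch


def assert_no_recompile(compiled_model, example_inputs, swap_fn):
    """Call the model, swap adapter weights in place, call again; dynamo raises if it recompiles."""
    with torch._dynamo.config.patch(error_on_recompile=True):
        compiled_model(**example_inputs)  # first call: initial compilation, allowed
        swap_fn()  # e.g. hotswap_adapter(...) or load_lora_adapter(..., hotswap=True)
        compiled_model(**example_inputs)  # must reuse the existing compiled graph
```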
118 changes: 118 additions & 0 deletions tests/test_initialization.py
@@ -3027,6 +3027,124 @@ def test_hotswap_extra_key_raises(self, tmp_path):
with pytest.raises(RuntimeError, match=msg):
hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default")

@pytest.mark.parametrize("ranks", [(7, 13), (13, 7)])
def test_hotswap_works_different_ranks_alphas(self, ranks, tmp_path):
# same as test_hotswap_works but different rank and alpha
# Load 2 different adapters and check that we can hotswap between them, with the model optionally being
# compiled.
atol, rtol = 1e-4, 1e-4
inputs = torch.rand(3, 10).to(self.torch_device)

# create adapter 0
config0 = LoraConfig(target_modules=["lin0", "lin1"], r=ranks[0], lora_alpha=ranks[0], init_lora_weights=False)
model = self.get_model()
torch.manual_seed(0)
model = get_peft_model(model, config0)
model.eval()
with torch.inference_mode():
output0 = model(inputs)
model.save_pretrained(tmp_path / "adapter0")

del model

# create adapter 1
config1 = LoraConfig(target_modules=["lin0"], r=ranks[1], lora_alpha=ranks[1], init_lora_weights=False)
model = self.get_model()
torch.manual_seed(1)
model = get_peft_model(model, config1)
model.eval()
with torch.inference_mode():
output1 = model(inputs)
model.save_pretrained(tmp_path / "adapter1")

# sanity check: they're not the same
assert not torch.allclose(output0, output1, atol=atol, rtol=rtol)

del model

# load adapter 0
model = self.get_model()
model = PeftModel.from_pretrained(model, tmp_path / "adapter0")
with torch.inference_mode():
output_loaded0 = model(inputs)

# sanity check: same output after loading for adapter 0
assert torch.allclose(output0, output_loaded0, atol=atol, rtol=rtol)

# hotswap with adapter 1
hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default")
with torch.inference_mode():
output_loaded1 = model(inputs)

# real check: model now behaves like adapter 1
assert torch.allclose(output1, output_loaded1, atol=atol, rtol=rtol)

# hotswap back to adapter 0
hotswap_adapter(model, tmp_path / "adapter0", adapter_name="default")
with torch.inference_mode():
output_loaded_back0 = model(inputs)

# real check: model now behaves again like adapter 0
assert torch.allclose(output0, output_loaded_back0, atol=atol, rtol=rtol)

@pytest.mark.parametrize("ranks", [(7, 13), (13, 7)])
def test_hotswap_works_different_ranks_alphas_conv2d(self, ranks, tmp_path):
# same as previous test, but for a Conv2d model
atol, rtol = 1e-4, 1e-4
inputs = torch.rand(3, 3, 10, 10).to(self.torch_device)

# create adapter 0
config0 = LoraConfig(target_modules=["conv"], r=ranks[0], init_lora_weights=False)
model = self.get_model_conv2d()
torch.manual_seed(0)
model = get_peft_model(model, config0)
model.eval()
with torch.inference_mode():
output0 = model(inputs)
model.save_pretrained(tmp_path / "adapter0")

del model

# create adapter 1
config1 = LoraConfig(target_modules=["conv"], r=ranks[1], init_lora_weights=False)
model = self.get_model_conv2d()
torch.manual_seed(1)
model = get_peft_model(model, config1)
model.eval()
with torch.inference_mode():
output1 = model(inputs)
model.save_pretrained(tmp_path / "adapter1")

# sanity check: they're not the same
assert not torch.allclose(output0, output1, atol=atol, rtol=rtol)

del model

# load adapter 0
model = self.get_model_conv2d()
model = PeftModel.from_pretrained(model, tmp_path / "adapter0")
with torch.inference_mode():
output_loaded0 = model(inputs)

# sanity check: same output after loading for adapter 0
assert torch.allclose(output0, output_loaded0, atol=atol, rtol=rtol)

# hotswap with adapter 1
hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default")
with torch.inference_mode():
output_loaded1 = model(inputs)

# real check: model now behaves like adapter 1
assert torch.allclose(output1, output_loaded1, atol=atol, rtol=rtol)

# hotswap back to adapter 0
hotswap_adapter(model, tmp_path / "adapter0", adapter_name="default")
with torch.inference_mode():
output_loaded_back0 = model(inputs)

# real check: model now behaves again like adapter 0
assert torch.allclose(output0, output_loaded_back0, atol=atol, rtol=rtol)

def test_prepare_model_for_compiled_hotswap_scalings_are_tensors(self):
config = LoraConfig(target_modules=["lin0", "lin1"])
model = self.get_model()