
Commit 70afa83

patrickvonplaten and BenjaminBossan authored and committed
Fix loading broken LoRAs that could give NaN (huggingface#5316)
* Fix fuse Lora
* improve a bit
* make style
* Update src/diffusers/models/lora.py
  Co-authored-by: Benjamin Bossan <[email protected]>
* ciao C file
* ciao C file
* test & make style

---------

Co-authored-by: Benjamin Bossan <[email protected]>
1 parent bc81e79 commit 70afa83
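
For context, a minimal usage sketch of the safe_fusing flag this commit introduces. The checkpoint and LoRA identifiers below are illustrative placeholders, not taken from the commit:

import torch
from diffusers import StableDiffusionXLPipeline

# Placeholder model and LoRA repo ids, used only for illustration.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-sdxl-lora")

# With safe_fusing=True, fusion raises a ValueError if the fused weights
# would contain NaN values (a broken LoRA), instead of silently fusing
# and producing black images.
pipe.fuse_lora(lora_scale=1.0, safe_fusing=True)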

File tree

3 files changed (+92 -17 lines):
src/diffusers/loaders.py
src/diffusers/models/lora.py
tests/lora/test_lora_layers_old_backend.py


src/diffusers/loaders.py

Lines changed: 33 additions & 15 deletions

@@ -121,7 +121,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
 
         return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
 
-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         if self.lora_linear_layer is None:
             return
 
@@ -135,6 +135,14 @@ def _fuse_lora(self, lora_scale=1.0):
             w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
 
         fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+
+        if safe_fusing and torch.isnan(fused_weight).any().item():
+            raise ValueError(
+                "This LoRA weight seems to be broken. "
+                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+                "LoRA weights will not be fused."
+            )
+
         self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
@@ -672,13 +680,14 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
-    def fuse_lora(self, lora_scale=1.0):
+    def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         self.lora_scale = lora_scale
+        self._safe_fusing = safe_fusing
         self.apply(self._fuse_lora_apply)
 
     def _fuse_lora_apply(self, module):
         if hasattr(module, "_fuse_lora"):
-            module._fuse_lora(self.lora_scale)
+            module._fuse_lora(self.lora_scale, self._safe_fusing)
 
     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)
@@ -2086,7 +2095,13 @@ def unload_lora_weights(self):
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
 
-    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
+    def fuse_lora(
+        self,
+        fuse_unet: bool = True,
+        fuse_text_encoder: bool = True,
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+    ):
         r"""
         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
 
@@ -2103,6 +2118,8 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
                 LoRA parameters then it won't have any effect.
             lora_scale (`float`, defaults to 1.0):
                 Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
         """
         if fuse_unet or fuse_text_encoder:
             self.num_fused_loras += 1
@@ -2112,12 +2129,13 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
             )
 
         if fuse_unet:
-            self.unet.fuse_lora(lora_scale)
+            self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
 
         if self.use_peft_backend:
             from peft.tuners.tuners_utils import BaseTunerLayer
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
+                # TODO(Patrick, Younes): enable "safe" fusing
                 for module in text_encoder.modules():
                     if isinstance(module, BaseTunerLayer):
                         if lora_scale != 1.0:
@@ -2129,24 +2147,24 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
             if version.parse(__version__) > version.parse("0.23"):
                 deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
                 for _, attn_module in text_encoder_attn_modules(text_encoder):
                     if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                        attn_module.q_proj._fuse_lora(lora_scale)
-                        attn_module.k_proj._fuse_lora(lora_scale)
-                        attn_module.v_proj._fuse_lora(lora_scale)
-                        attn_module.out_proj._fuse_lora(lora_scale)
+                        attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
 
                 for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                     if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                        mlp_module.fc1._fuse_lora(lora_scale)
-                        mlp_module.fc2._fuse_lora(lora_scale)
+                        mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
+                        mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
 
         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
-                fuse_text_encoder_lora(self.text_encoder, lora_scale)
+                fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
             if hasattr(self, "text_encoder_2"):
-                fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
+                fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
 
     def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
         r"""

src/diffusers/models/lora.py

Lines changed: 18 additions & 2 deletions

@@ -112,7 +112,7 @@ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs
     def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
         self.lora_layer = lora_layer
 
-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         if self.lora_layer is None:
             return
 
@@ -128,6 +128,14 @@ def _fuse_lora(self, lora_scale=1.0):
         fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
         fusion = fusion.reshape((w_orig.shape))
         fused_weight = w_orig + (lora_scale * fusion)
+
+        if safe_fusing and torch.isnan(fused_weight).any().item():
+            raise ValueError(
+                "This LoRA weight seems to be broken. "
+                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+                "LoRA weights will not be fused."
+            )
+
         self.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
@@ -182,7 +190,7 @@ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs
     def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
         self.lora_layer = lora_layer
 
-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         if self.lora_layer is None:
             return
 
@@ -196,6 +204,14 @@ def _fuse_lora(self, lora_scale=1.0):
             w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
 
         fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+
+        if safe_fusing and torch.isnan(fused_weight).any().item():
+            raise ValueError(
+                "This LoRA weight seems to be broken. "
+                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+                "LoRA weights will not be fused."
+            )
+
         self.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
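
For the conv path, the low-rank update is formed by flattening both LoRA kernels to 2D, multiplying, and reshaping back to the conv weight's shape, guarded by the same NaN check. A hedged sketch with invented shapes (rank 2, 8 output and 4 input channels, 3x3 kernel), not taken from the library:

import torch

def fuse_conv_lora(weight, w_up, w_down, lora_scale=1.0, safe_fusing=False):
    # Flatten the conv kernels to 2D, multiply, and reshape back,
    # mirroring the conv _fuse_lora path in the diff above.
    fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
    fusion = fusion.reshape(weight.shape)
    fused = weight + lora_scale * fusion
    if safe_fusing and torch.isnan(fused).any().item():
        raise ValueError("Broken LoRA: fused conv weights contain NaN values.")
    return fused

w = torch.randn(8, 4, 3, 3)       # conv weight: (out, in, kH, kW)
w_up = torch.randn(8, 2, 1, 1)    # up projection: rank -> out channels
w_down = torch.randn(2, 4, 3, 3)  # down projection: in channels -> rank
fused = fuse_conv_lora(w, w_up, w_down)
assert fused.shape == w.shape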

tests/lora/test_lora_layers_old_backend.py

Lines changed: 41 additions & 0 deletions

@@ -1028,6 +1028,47 @@ def test_load_lora_locally_safetensors(self):
 
         sd_pipe.unload_lora_weights()
 
+    def test_lora_fuse_nan(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=True,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        # corrupt one LoRA weight with `inf` values
+        with torch.no_grad():
+            sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float(
+                "inf"
+            )
+
+        # with `safe_fusing=True` we should see an Error
+        with self.assertRaises(ValueError):
+            sd_pipe.fuse_lora(safe_fusing=True)
+
+        # without we should not see an error, but every image will be black
+        sd_pipe.fuse_lora(safe_fusing=False)
+
+        out = sd_pipe("test", num_inference_steps=2, output_type="np").images
+
+        assert np.isnan(out).all()
+
     def test_lora_fusion(self):
         pipeline_components, lora_components = self.get_dummy_components()
         sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
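
A note on why the test corrupts a weight with inf but asserts on NaN: the fused weight is a matrix product, so inf multiplied by mixed-sign terms sums to inf + (-inf) = NaN. A tiny illustration with made-up numbers, separate from the test:

import torch

up = torch.tensor([[1.0, -1.0]])          # mixed-sign up projection
down = torch.full((2, 3), float("inf"))   # down projection corrupted with inf

# Each output element is 1.0*inf + (-1.0)*inf = nan, which is exactly what
# safe_fusing=True detects before overwriting the original weights.
print(up @ down)  # tensor([[nan, nan, nan]])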
