@@ -121,7 +121,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
 
         return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
 
-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         if self.lora_linear_layer is None:
             return
 
@@ -135,6 +135,14 @@ def _fuse_lora(self, lora_scale=1.0):
             w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
 
         fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+
+        if safe_fusing and torch.isnan(fused_weight).any().item():
+            raise ValueError(
+                "This LoRA weight seems to be broken. "
+                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+                "LoRA weights will not be fused."
+            )
+
         self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
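The change above guards a plain low-rank update: the LoRA up and down matrices are multiplied together, scaled, and added onto the original weight, and with `safe_fusing` enabled the result is rejected if it contains NaNs. A minimal standalone sketch of that check; the function name and shapes here are illustrative, not part of the patch:

import torch

def fuse_weight(w_orig, w_up, w_down, lora_scale=1.0, safe_fusing=False):
    # W_fused = W_orig + scale * (W_up @ W_down), computed in float32 as in the patch
    fused = w_orig.float() + lora_scale * torch.bmm(w_up.float()[None, :], w_down.float()[None, :])[0]
    if safe_fusing and torch.isnan(fused).any().item():
        # refuse to overwrite the original weight with a corrupted fusion
        raise ValueError("Encountered NaN values while fusing LoRA weights.")
    return fused.to(dtype=w_orig.dtype)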
@@ -672,13 +680,14 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
-    def fuse_lora(self, lora_scale=1.0):
+    def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         self.lora_scale = lora_scale
+        self._safe_fusing = safe_fusing
         self.apply(self._fuse_lora_apply)
 
     def _fuse_lora_apply(self, module):
         if hasattr(module, "_fuse_lora"):
-            module._fuse_lora(self.lora_scale)
+            module._fuse_lora(self.lora_scale, self._safe_fusing)
 
     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)
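At the model level, `fuse_lora` simply records the scale and the new `safe_fusing` flag and lets `torch.nn.Module.apply` forward the call to every submodule that implements `_fuse_lora`. A hedged usage sketch; the `unet` object is assumed to already have LoRA layers attached and is not defined in this commit:

# unet is any model carrying the loader mixin above, with LoRA weights loaded
unet.fuse_lora(lora_scale=0.8, safe_fusing=True)  # raises ValueError if any fused weight contains NaN
# ... run inference with the fused weights ...
unet.unfuse_lora()  # restore the original, unfused weights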
@@ -2086,7 +2095,13 @@ def unload_lora_weights(self):
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
 
-    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
+    def fuse_lora(
+        self,
+        fuse_unet: bool = True,
+        fuse_text_encoder: bool = True,
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+    ):
         r"""
         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
 
@@ -2103,6 +2118,8 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
                 LoRA parameters then it won't have any effect.
             lora_scale (`float`, defaults to 1.0):
                 Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check the fused weights for NaN values before fusing and skip fusing if NaN values are found.
         """
         if fuse_unet or fuse_text_encoder:
             self.num_fused_loras += 1
@@ -2112,12 +2129,13 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
             )
 
         if fuse_unet:
-            self.unet.fuse_lora(lora_scale)
+            self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
 
         if self.use_peft_backend:
             from peft.tuners.tuners_utils import BaseTunerLayer
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
+                # TODO(Patrick, Younes): enable "safe" fusing
                 for module in text_encoder.modules():
                     if isinstance(module, BaseTunerLayer):
                         if lora_scale != 1.0:
@@ -2129,24 +2147,24 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
             if version.parse(__version__) > version.parse("0.23"):
                 deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
                 for _, attn_module in text_encoder_attn_modules(text_encoder):
                     if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                        attn_module.q_proj._fuse_lora(lora_scale)
-                        attn_module.k_proj._fuse_lora(lora_scale)
-                        attn_module.v_proj._fuse_lora(lora_scale)
-                        attn_module.out_proj._fuse_lora(lora_scale)
+                        attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
+                        attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
 
                 for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                     if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                        mlp_module.fc1._fuse_lora(lora_scale)
-                        mlp_module.fc2._fuse_lora(lora_scale)
+                        mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
+                        mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
 
         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
-                fuse_text_encoder_lora(self.text_encoder, lora_scale)
+                fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
             if hasattr(self, "text_encoder_2"):
-                fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
+                fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
 
     def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
         r"""
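End to end, the new flag flows from the pipeline call down to each patched projection. A hedged usage sketch; the model and LoRA paths below are placeholders and are not part of this commit:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.load_lora_weights("path/to/lora")  # placeholder LoRA checkpoint

# Fuse with the NaN check enabled: a broken LoRA raises a ValueError
# instead of silently corrupting the UNet and text encoder weights.
pipe.fuse_lora(lora_scale=0.7, safe_fusing=True)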