 import logging
+import warnings
 import os
 from functools import partial
 from pathlib import Path
 from types import MethodType
 from typing import Callable, Dict, Iterator, List, Optional, Tuple
 
-from peft import LoraConfig, TaskType, get_peft_model
-
 import torch
 import torch.nn as nn
+from torch.nn import Parameter
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
@@ -335,13 +335,44 @@ def enable_lora(
         from peft import PeftModel, get_peft_model
         assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
         self.lora_enabled = True
+        warnings.warn("You have enabled LoRA training. Please check the hyperparameters, such as the learning rate.")
 
         if pretrained_dir is None:
             peft_model = get_peft_model(model, lora_config)
         else:
             peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
         return peft_model
 
+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
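+        """Return the index of the optimizer param group containing origin_param, or -1 if not found."""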
+        origin_param_id = id(origin_param)
+        for group_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group['params']:
+                if id(p) == origin_param_id:
+                    return group_id
+        return -1
+
+    def add_lora_para_to_optimizer(self, model, optimizer):
+        """Add newly created LoRA parameters to the optimizer's existing param groups."""
+        name2param = {}
+        for name, param in model.named_parameters():
+            name2param[name] = param
+
+        optimizer_param_nums = 0
+        for param_group in optimizer.param_groups:
+            optimizer_param_nums += len(param_group['params'])
+
+        # If the optimizer was created before the model was converted into a LoRA model,
+        # its param groups are missing the LoRA parameters; add them here.
+        if len(name2param) != optimizer_param_nums:
+            for name, param in name2param.items():
+                if 'lora_A' in name or 'lora_B' in name:
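+                    # Strip the "lora_A."/"lora_B." and active-adapter segments from the parameter
+                    # name to recover the name of the corresponding base (frozen) parameter, then
+                    # append the LoRA weight to that base parameter's param group.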
+                    origin_key = name.replace("lora_A.", "")
+                    origin_key = origin_key.replace("lora_B.", "")
+                    origin_key = origin_key.replace(f"{model.active_adapter}.", "")
+                    origin_param = name2param[origin_key]
+                    group_id = self.get_param_group_id(optimizer, origin_param)
+                    assert group_id != -1, "Parameter error: the origin parameter does not exist."
+                    optimizer.param_groups[group_id]['params'].append(param)
+
     def configure(
         self,
         model: nn.Module,
@@ -353,12 +384,8 @@ def configure(
         if self.lora_enabled:
             from peft import PeftModel
             assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
-
-            optim_params_nums = 0
-            for param_group in optimizer.param_groups:
-                optim_params_nums += len(param_group['params'])
-            model_params_nums = len(list(model.named_parameters()))
-            assert optim_params_nums == model_params_nums, "Optimizer should be initialized after enabling lora."
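+            # The optimizer may have been built before enable_lora added the LoRA weights,
+            # so back-fill those parameters into the matching optimizer param groups.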
+            self.add_lora_para_to_optimizer(model, optimizer)
+
 
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)