diff --git a/scripts/dpo.py b/scripts/dpo.py index 07df54c7..a5c3168d 100644 --- a/scripts/dpo.py +++ b/scripts/dpo.py @@ -98,7 +98,10 @@ def main(script_args, training_args, model_args): # Model & Tokenizer ################### model = get_model(model_args, training_args) - ref_model = get_model(model_args, training_args) + if model_args.use_peft: + ref_model = None + else: + ref_model = get_model(model_args, training_args) tokenizer = get_tokenizer(model_args, training_args) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token