If I fine-tune a flan-t5 model using LoRA, I can't seem to use load_best_model_at_end=True.
```python
from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_int8_training

model = AutoModelForSeq2SeqLM.from_pretrained(
    'google/flan-t5-large', use_cache=False, load_in_8bit=True, device_map='auto'
)
model = prepare_model_for_int8_training(model)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)
model = get_peft_model(model, lora_config)
```
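For reference, a quick sanity check that the LoRA wrapping took effect (optional, uses the standard PEFT helper):

```python
# Optional sanity check: only the LoRA adapter weights should be trainable here.
model.print_trainable_parameters()
```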
...
```python
from transformers import Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir=model_dir,
    evaluation_strategy="epoch",
    per_device_train_batch_size=cfg.batch_size,
    per_device_eval_batch_size=cfg.batch_size,
    weight_decay=cfg.weight_decay,
    num_train_epochs=cfg.n_epochs,
    save_strategy="epoch",
    predict_with_generate=True,
    load_best_model_at_end=True,
)
```

This fails inside trainer.train(): training itself finishes, but the run errors out when the Trainer reloads the checkpoint with the best validation accuracy in order to save it.
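The part elided with "..." above presumably builds a standard Seq2SeqTrainer along these lines; the tokenizer and dataset names here are placeholders, not from the original setup:

```python
from transformers import Seq2SeqTrainer, DataCollatorForSeq2Seq

# Hypothetical reconstruction of the elided trainer setup; train_dataset,
# eval_dataset and tokenizer are placeholder names.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    tokenizer=tokenizer,
)
trainer.train()  # raises the RuntimeError below once the best checkpoint is reloaded
```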
The error is:
```
RuntimeError: Loading a quantized checkpoint into non-quantized Linear8bitLt is not supported. Please call module.cuda() before module.load_state_dict()
```
This is on Ubuntu 22.04 with an NVIDIA RTX 4090.
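A workaround that should sidestep the reload (an untested sketch, not something confirmed in this thread) is to drop load_best_model_at_end, save only the LoRA adapter with PEFT, and re-attach it to a freshly quantized base model afterwards:

```python
from peft import PeftModel

# Untested sketch: save only the adapter weights, then re-attach them to a
# freshly loaded 8-bit base model instead of letting the Trainer reload the
# full quantized checkpoint.
model.save_pretrained(model_dir)  # writes adapter_config.json + adapter weights only

base_model = AutoModelForSeq2SeqLM.from_pretrained(
    'google/flan-t5-large', load_in_8bit=True, device_map='auto'
)
model = PeftModel.from_pretrained(base_model, model_dir)
```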