
load_best_model_at_end doesn't work if load_in_8bit=True #394

@ChrisAGBlake

Description

When I fine-tune a flan-t5 model with LoRA on top of an 8-bit base model, I can't use load_best_model_at_end=True.

from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_int8_training

# Load the base model in 8-bit and prepare it for int8 training.
model = AutoModelForSeq2SeqLM.from_pretrained(
    'google/flan-t5-large', use_cache=False, load_in_8bit=True, device_map='auto'
)
model = prepare_model_for_int8_training(model)

# Attach LoRA adapters to the query and value projections.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)
model = get_peft_model(model, lora_config)
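
As a sanity check (my addition, not in the original report), PEFT models expose print_trainable_parameters(), which confirms that only the LoRA adapter tensors are trainable while the 8-bit base weights stay frozen:

model.print_trainable_parameters()  # prints trainable vs. total parameter counts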

...

from transformers import Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir=model_dir,
    evaluation_strategy="epoch",
    per_device_train_batch_size=cfg.batch_size,
    per_device_eval_batch_size=cfg.batch_size,
    weight_decay=cfg.weight_decay,
    num_train_epochs=cfg.n_epochs,
    save_strategy="epoch",
    predict_with_generate=True,
    load_best_model_at_end=True,
)

Training itself completes, but trainer.train() then fails at the end, when the Trainer tries to reload the checkpoint with the best validation metric.

The error is:
RuntimeError: Loading a quantized checkpoint into non-quantized Linear8bitLt is not supported. Please call module.cuda() before module.load_state_dict()

This is on Ubuntu 22.04 with an NVIDIA RTX 4090.
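
A possible workaround (a sketch of my own, not from the issue): since only the LoRA adapter tensors are trainable, the failing full-state-dict load can be avoided by dropping load_best_model_at_end=True and restoring just the adapter weights from the best checkpoint after training. This assumes metric_for_best_model is set (e.g. to "eval_loss") so that trainer.state.best_model_checkpoint gets populated, and that each checkpoint directory contains a pytorch_model.bin state dict; adjust the filename if your setup writes adapter_model.bin instead.

import os
import torch

trainer.train()

# best_model_checkpoint is only tracked when metric_for_best_model is set.
best_ckpt = trainer.state.best_model_checkpoint
if best_ckpt is not None:
    state = torch.load(os.path.join(best_ckpt, 'pytorch_model.bin'), map_location='cpu')
    # Keep only the LoRA adapter tensors; the quantized base weights are ignored.
    lora_state = {k: v for k, v in state.items() if 'lora_' in k}
    # strict=False skips the Linear8bitLt base weights entirely, so the
    # "Loading a quantized checkpoint into non-quantized Linear8bitLt" error never triggers.
    model.load_state_dict(lora_state, strict=False)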
