LoRA adapters: non revertible fuse #5313

@oOraph

Describe the bug

Loading certain LoRA adapters produces a black image.

During inference, the following warning is emitted:

/lib/python3.10/site-packages/diffusers/image_processor.py:88: RuntimeWarning: invalid value encountered in cast
  images = (images * 255).round().astype("uint8")

This is probably an issue in itself, but the main problem is elsewhere: if we fuse the faulty adapter into the base model for faster inference, the unfuse does not fully reverse the fuse, so every adapter loaded afterwards also produces black images.
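
The report does not identify the root cause. As a minimal sketch of one plausible mechanism, assuming the faulty adapter contributes non-finite values (NaN or inf, for example through fp16 overflow), the following shows why an add-then-subtract round trip cannot restore the original weights in that case. The `fuse` and `unfuse` helpers here are hypothetical stand-ins operating on plain Python lists, not the diffusers implementation.

```python
import math


def fuse(weight, delta):
    """Add a LoRA-style update to the base weight, element-wise."""
    return [w + d for w, d in zip(weight, delta)]


def unfuse(fused, delta):
    """Subtract the same update again, intended to restore the base weight."""
    return [f - d for f, d in zip(fused, delta)]


base_weight = [0.5, -1.25, 2.0]

# A well-behaved update: subtracting it restores the base weight.
good_delta = [0.25, 0.5, -1.0]
restored = unfuse(fuse(base_weight, good_delta), good_delta)

# A hypothetical faulty update with non-finite entries (assumption, see above).
bad_delta = [0.25, math.inf, math.nan]
corrupted = unfuse(fuse(base_weight, bad_delta), bad_delta)

print("restored :", restored)
print("corrupted:", corrupted)
```

With finite values the round trip returns the base weight. With a non-finite entry, the fused value is already inf or NaN, and subtracting the update leaves NaN (inf minus inf is NaN, and NaN propagates through arithmetic), so the base weight is lost. If something like this happens inside the real fuse, it would be consistent with every later adapter producing black images.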

Reproduced with the script below on the diffusers main branch, on T4 and A10G GPUs. The problem does not reproduce if the fuse/unfuse lines are removed: in that case only adapter2 produces a black image.

Reproduction

import hashlib

import torch
from diffusers import (
    AutoencoderKL,
    DiffusionPipeline,
)

base = "stabilityai/stable-diffusion-xl-base-1.0"
adapter1 = 'nerijs/pixel-art-xl'
weightname1 = 'pixel-art-xl.safetensors'

adapter2 = 'Alexzyx/lora-trained-xl-colab'
weightname2 = None

inputs = "elephant"
kwargs = {}

if torch.cuda.is_available():
    kwargs["torch_dtype"] = torch.float16

#vae = AutoencoderKL.from_pretrained(
#    "madebyollin/sdxl-vae-fp16-fix",
#    torch_dtype=torch.float16,  # load fp16 fix VAE
#)
#kwargs["vae"] = vae
#kwargs["variant"] = "fp16"
#

model = DiffusionPipeline.from_pretrained(
    base, **kwargs
)

if torch.cuda.is_available():
    model.to("cuda")


def inference(adapter, weightname):
    # Load the adapter and fuse it into the base weights for faster inference.
    model.load_lora_weights(adapter, weight_name=weightname)
    model.fuse_lora()
    data = model(inputs, num_inference_steps=1).images[0]
    # Undo the fuse and drop the adapter; this should restore the base model.
    model.unfuse_lora()
    model.unload_lora_weights()
    filename = '/tmp/hello.jpg'
    data.save(filename, format='jpeg')
    with open(filename, 'rb') as f:
        md5 = hashlib.md5(f.read()).hexdigest()
    print("Adapter %s, md5sum %s" % (adapter, md5))
    # md5 of the black JPEG, as observed in the logs below.
    if md5 == '40c78c9fd4daeff01c988c3532fdd51b':
        print("BLACK SCREEN IMAGE for adapter %s" % adapter)


inference(adapter1, weightname1)
inference(adapter2, weightname2)
inference(adapter1, weightname1)
inference(adapter1, weightname1)

Logs

Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:41<00:00,  5.95s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.54s/it]
Adapter nerijs/pixel-art-xl, md5sum 7906455789f5fdfd94f0793d4c026563
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.63it/s]
/home/drive/venvs/api-inference-community/lib/python3.10/site-packages/diffusers/image_processor.py:88: RuntimeWarning: invalid value encountered in cast
  images = (images * 255).round().astype("uint8")
Adapter Alexzyx/lora-trained-xl-colab, md5sum 40c78c9fd4daeff01c988c3532fdd51b
BLACK SCREEN IMAGE for adapter Alexzyx/lora-trained-xl-colab
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.62it/s]
Adapter nerijs/pixel-art-xl, md5sum 40c78c9fd4daeff01c988c3532fdd51b
BLACK SCREEN IMAGE for adapter nerijs/pixel-art-xl
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.61it/s]
Adapter nerijs/pixel-art-xl, md5sum 40c78c9fd4daeff01c988c3532fdd51b
BLACK SCREEN IMAGE for adapter nerijs/pixel-art-xl
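
The logs above show the first adapter turning black once the faulty one has been fused and unfused. One way to check whether the base weights were left corrupted might be to scan a module for non-finite parameters right after `unload_lora_weights()`. This is only a sketch: `non_finite_parameters` is a hypothetical helper, the `torch.nn.Linear` layer is a stand-in, and applying it to `model.unet` assumes that component is a regular `torch.nn.Module`.

```python
import torch


def non_finite_parameters(module: torch.nn.Module) -> list[str]:
    """Return the names of parameters that contain NaN or inf values."""
    return [
        name
        for name, param in module.named_parameters()
        if not torch.isfinite(param).all()
    ]


# Stand-in module; in the reproduction script this would be called as
# non_finite_parameters(model.unet) after model.unload_lora_weights().
layer = torch.nn.Linear(4, 2)
print(non_finite_parameters(layer))

# Simulate a weight left corrupted by an incomplete unfuse.
with torch.no_grad():
    layer.weight[0, 0] = float("nan")
print(non_finite_parameters(layer))
```

If this reported parameter names after the second `inference` call but not after the first, it would suggest the unfuse leaves non-finite values behind in the base model.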


### System Info

- `diffusers` version: 0.22.0.dev0
- Platform: Linux-6.2.0-1012-aws-x86_64-with-glibc2.35
- Python version: 3.10.12
- PyTorch version (GPU?): 2.0.1+cu117 (True)
- Huggingface_hub version: 0.17.2
- Transformers version: 4.31.0
- Accelerate version: 0.21.0
- xFormers version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no

### Who can help?

@sayakpaul @patrickvonplaten 
