|
67 | 67 | convert_state_dict_to_kohya, |
68 | 68 | is_wandb_available, |
69 | 69 | ) |
| 70 | +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card |
70 | 71 | from diffusers.utils.import_utils import is_xformers_available |
71 | 72 |
|
72 | 73 |
|
|
79 | 80 | def save_model_card( |
80 | 81 | repo_id: str, |
81 | 82 | use_dora: bool, |
82 | | - images=None, |
83 | | - base_model=str, |
| 83 | + images: list = None, |
| 84 | + base_model: str = None, |
84 | 85 | train_text_encoder=False, |
85 | 86 | train_text_encoder_ti=False, |
86 | 87 | token_abstraction_dict=None, |
87 | | - instance_prompt=str, |
88 | | - validation_prompt=str, |
| 88 | + instance_prompt=None, |
| 89 | + validation_prompt=None, |
89 | 90 | repo_folder=None, |
90 | 91 | vae_path=None, |
91 | 92 | ): |
92 | | - img_str = "widget:\n" |
93 | 93 | lora = "lora" if not use_dora else "dora" |
94 | | - for i, image in enumerate(images): |
95 | | - image.save(os.path.join(repo_folder, f"image_{i}.png")) |
96 | | - img_str += f""" |
97 | | - - text: '{validation_prompt if validation_prompt else ' ' }' |
98 | | - output: |
99 | | - url: |
100 | | - "image_{i}.png" |
101 | | - """ |
102 | | - if not images: |
103 | | - img_str += f""" |
104 | | - - text: '{instance_prompt}' |
105 | | - """ |
| 94 | + |
| 95 | + widget_dict = [] |
| 96 | + if images is not None: |
| 97 | + for i, image in enumerate(images): |
| 98 | + image.save(os.path.join(repo_folder, f"image_{i}.png")) |
| 99 | + widget_dict.append( |
| 100 | + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} |
| 101 | + ) |
| 102 | + else: |
| 103 | + widget_dict.append({"text": instance_prompt}) |
106 | 104 | embeddings_filename = f"{repo_folder}_emb" |
107 | 105 | instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1)) |
108 | 106 | ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt)) |
@@ -137,24 +135,7 @@ def save_model_card( |
137 | 135 | trigger_str += f""" |
138 | 136 | to trigger concept `{key}` → use `{tokens}` in your prompt \n |
139 | 137 | """ |
140 | | - |
141 | | - yaml = f"""--- |
142 | | -tags: |
143 | | -- stable-diffusion |
144 | | -- stable-diffusion-diffusers |
145 | | -- diffusers-training |
146 | | -- text-to-image |
147 | | -- diffusers |
148 | | -- {lora} |
149 | | -- template:sd-lora |
150 | | -{img_str} |
151 | | -base_model: {base_model} |
152 | | -instance_prompt: {instance_prompt} |
153 | | -license: openrail++ |
154 | | ---- |
155 | | -""" |
156 | | - |
157 | | - model_card = f""" |
| 138 | + model_description = f""" |
158 | 139 | # SD1.5 LoRA DreamBooth - {repo_id} |
159 | 140 |
|
160 | 141 | <Gallery /> |
@@ -202,8 +183,28 @@ def save_model_card( |
202 | 183 | Special VAE used for training: {vae_path}. |
203 | 184 |
|
204 | 185 | """ |
205 | | - with open(os.path.join(repo_folder, "README.md"), "w") as f: |
206 | | - f.write(yaml + model_card) |
| 186 | + model_card = load_or_create_model_card( |
| 187 | + repo_id_or_path=repo_id, |
| 188 | + from_training=True, |
| 189 | + license="openrail++", |
| 190 | + base_model=base_model, |
| 191 | + prompt=instance_prompt, |
| 192 | + model_description=model_description, |
| 193 | + inference=True, |
| 194 | + widget=widget_dict, |
| 195 | + ) |
| 196 | + |
| 197 | + tags = [ |
| 198 | + "text-to-image", |
| 199 | + "diffusers", |
| 200 | + "diffusers-training", |
| 201 | + lora, |
| 202 | + "template:sd-lora" "stable-diffusion", |
| 203 | + "stable-diffusion-diffusers", |
| 204 | + ] |
| 205 | + model_card = populate_model_card(model_card, tags=tags) |
| 206 | + |
| 207 | + model_card.save(os.path.join(repo_folder, "README.md")) |
207 | 208 |
|
208 | 209 |
|
209 | 210 | def import_model_class_from_model_name_or_path( |
|
0 commit comments