|
15 | 15 |
|
16 | 16 | import argparse |
17 | 17 | import copy |
18 | | -import gc |
19 | 18 | import itertools |
20 | 19 | import logging |
21 | 20 | import math |
|
56 | 55 | from diffusers.training_utils import ( |
57 | 56 | _set_state_dict_into_text_encoder, |
58 | 57 | cast_training_params, |
| 58 | + clear_objs_and_retain_memory, |
59 | 59 | compute_density_for_timestep_sampling, |
60 | 60 | compute_loss_weighting_for_sd3, |
61 | 61 | ) |
@@ -210,9 +210,7 @@ def log_validation( |
210 | 210 | } |
211 | 211 | ) |
212 | 212 |
|
213 | | - del pipeline |
214 | | - if torch.cuda.is_available(): |
215 | | - torch.cuda.empty_cache() |
| 213 | + clear_objs_and_retain_memory(objs=[pipeline]) |
216 | 214 |
|
217 | 215 | return images |
218 | 216 |
|
@@ -1107,9 +1105,7 @@ def main(args): |
1107 | 1105 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" |
1108 | 1106 | image.save(image_filename) |
1109 | 1107 |
|
1110 | | - del pipeline |
1111 | | - if torch.cuda.is_available(): |
1112 | | - torch.cuda.empty_cache() |
| 1108 | + clear_objs_and_retain_memory(objs=[pipeline]) |
1113 | 1109 |
|
1114 | 1110 | # Handle the repository creation |
1115 | 1111 | if accelerator.is_main_process: |
@@ -1455,12 +1451,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): |
1455 | 1451 |
|
1456 | 1452 | # Clear the memory here |
1457 | 1453 | if not args.train_text_encoder and not train_dataset.custom_instance_prompts: |
1458 | | - del tokenizers, text_encoders |
1459 | 1454 | # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection |
1460 | | - del text_encoder_one, text_encoder_two, text_encoder_three |
1461 | | - gc.collect() |
1462 | | - if torch.cuda.is_available(): |
1463 | | - torch.cuda.empty_cache() |
| 1455 | + clear_objs_and_retain_memory( |
| 1456 | + objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three] |
| 1457 | + ) |
1464 | 1458 |
|
1465 | 1459 | # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), |
1466 | 1460 | # pack the statically computed variables appropriately here. This is so that we don't |
@@ -1795,11 +1789,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): |
1795 | 1789 | pipeline_args=pipeline_args, |
1796 | 1790 | epoch=epoch, |
1797 | 1791 | ) |
| 1792 | + objs = [] |
1798 | 1793 | if not args.train_text_encoder: |
1799 | | - del text_encoder_one, text_encoder_two, text_encoder_three |
| 1794 | + objs.extend([text_encoder_one, text_encoder_two, text_encoder_three]) |
1800 | 1795 |
|
1801 | | - torch.cuda.empty_cache() |
1802 | | - gc.collect() |
| 1796 | + clear_objs_and_retain_memory(objs=objs) |
1803 | 1797 |
|
1804 | 1798 | # Save the lora layers |
1805 | 1799 | accelerator.wait_for_everyone() |
|
0 commit comments