 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
 
 
 @invocation(
@@ -52,47 +53,23 @@ class ImageToLatentsInvocation(BaseInvocation):
     tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
     fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
 
-    def _estimate_working_memory(
-        self, image_tensor: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
-    ) -> int:
-        """Estimate the working memory required by the invocation in bytes."""
-        # Encode operations use approximately 50% of the memory required for decode operations
-        element_size = 4 if self.fp32 else 2
-        scaling_constant = 1100  # 50% of decode scaling constant (2200)
-
-        if use_tiling:
-            tile_size = self.tile_size
-            if tile_size == 0:
-                tile_size = vae.tile_sample_min_size
-            assert isinstance(tile_size, int)
-            h = tile_size
-            w = tile_size
-            working_memory = h * w * element_size * scaling_constant
-
-            # We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
-            # and number of tiles. We could make this more precise in the future, but this should be good enough for
-            # most use cases.
-            working_memory = working_memory * 1.25
-        else:
-            h = image_tensor.shape[-2]
-            w = image_tensor.shape[-1]
-            working_memory = h * w * element_size * scaling_constant
-
-        if self.fp32:
-            # If we are running in FP32, then we should account for the likely increase in model size (~250MB).
-            working_memory += 250 * 2**20
-
-        return int(working_memory)
-
-    @staticmethod
+    @classmethod
     def vae_encode(
+        cls,
         vae_info: LoadedModel,
         upcast: bool,
         tiled: bool,
         image_tensor: torch.Tensor,
         tile_size: int = 0,
-        estimated_working_memory: int = 0,
     ) -> torch.Tensor:
+        assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
+        estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
+            operation="encode",
+            image_tensor=image_tensor,
+            vae=vae_info.model,
+            tile_size=tile_size if tiled else None,
+            fp32=upcast,
+        )
         with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             orig_dtype = vae.dtype
@@ -156,17 +133,13 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         if image_tensor.dim() == 3:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
-        use_tiling = self.tiled or context.config.get().force_tiled_decode
-        estimated_working_memory = self._estimate_working_memory(image_tensor, use_tiling, vae_info.model)
-
         context.util.signal_progress("Running VAE encoder")
         latents = self.vae_encode(
             vae_info=vae_info,
             upcast=self.fp32,
-            tiled=self.tiled,
+            tiled=self.tiled or context.config.get().force_tiled_decode,
             image_tensor=image_tensor,
             tile_size=self.tile_size,
-            estimated_working_memory=estimated_working_memory,
         )
 
         latents = latents.to("cpu")
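
The new call site passes `operation`, `image_tensor`, `vae`, `tile_size` (or `None` when tiling is off), and `fp32`. For reference, here is a minimal sketch of an estimator with that shape, reconstructed from the removed `_estimate_working_memory` logic rather than from the real `invokeai.backend.util.vae_working_memory` module; the function name, the decode constant (taken from the removed comment), and the exact behaviour of the actual helper are assumptions.

```python
# Sketch only: a hypothetical working-memory estimator mirroring the removed method,
# not the actual estimate_vae_working_memory_sd15_sdxl implementation.
import torch
from diffusers import AutoencoderKL, AutoencoderTiny


def estimate_vae_working_memory_sketch(
    operation: str,  # "encode" or "decode"
    image_tensor: torch.Tensor,
    vae: AutoencoderKL | AutoencoderTiny,
    tile_size: int | None,
    fp32: bool,
) -> int:
    """Estimate the VAE working memory required, in bytes."""
    element_size = 4 if fp32 else 2
    # Encode uses roughly 50% of the memory of decode (1100 vs. the 2200 decode
    # scaling constant mentioned in the removed comment).
    scaling_constant = 1100 if operation == "encode" else 2200

    if tile_size is not None:
        if tile_size == 0:
            # tile_size == 0 means "use the model's default tile size".
            tile_size = vae.tile_sample_min_size
        working_memory = tile_size * tile_size * element_size * scaling_constant
        # Add 25% headroom for tile overlap and tile count.
        working_memory *= 1.25
    else:
        h, w = image_tensor.shape[-2], image_tensor.shape[-1]
        working_memory = h * w * element_size * scaling_constant

    if fp32:
        # Allow ~250MB extra for the larger FP32 model weights.
        working_memory += 250 * 2**20

    return int(working_memory)
```

Centralizing this in a shared helper lets encode and decode invocations reuse one estimate instead of each class carrying its own copy, which is what allows `vae_encode` to drop the `estimated_working_memory` parameter above.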