diff --git a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py index b57e776e49eb..e1f508dc1e36 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py @@ -238,14 +238,14 @@ def _run_safety_checker(self, images, safety_model_params, jit=False): def _generate( self, - prompt_ids: jnp.array, - image: jnp.array, + prompt_ids: jnp.ndarray, + image: jnp.ndarray, params: Union[Dict, FrozenDict], prng_seed: jax.Array, num_inference_steps: int, guidance_scale: float, - latents: Optional[jnp.array] = None, - neg_prompt_ids: Optional[jnp.array] = None, + latents: Optional[jnp.ndarray] = None, + neg_prompt_ids: Optional[jnp.ndarray] = None, controlnet_conditioning_scale: float = 1.0, ): height, width = image.shape[-2:] @@ -348,15 +348,15 @@ def loop_body(step, args): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt_ids: jnp.array, - image: jnp.array, + prompt_ids: jnp.ndarray, + image: jnp.ndarray, params: Union[Dict, FrozenDict], prng_seed: jax.Array, num_inference_steps: int = 50, - guidance_scale: Union[float, jnp.array] = 7.5, - latents: jnp.array = None, - neg_prompt_ids: jnp.array = None, - controlnet_conditioning_scale: Union[float, jnp.array] = 1.0, + guidance_scale: Union[float, jnp.ndarray] = 7.5, + latents: jnp.ndarray = None, + neg_prompt_ids: jnp.ndarray = None, + controlnet_conditioning_scale: Union[float, jnp.ndarray] = 1.0, return_dict: bool = True, jit: bool = False, ): @@ -364,13 +364,13 @@ def __call__( The call function to the pipeline for generation. Args: - prompt_ids (`jnp.array`): + prompt_ids (`jnp.ndarray`): The prompt or prompts to guide the image generation. - image (`jnp.array`): + image (`jnp.ndarray`): Array representing the ControlNet input condition to provide guidance to the `unet` for generation. params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights. - prng_seed (`jax.Array` or `jax.Array`): + prng_seed (`jax.Array`): Array containing random number generator key. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -378,11 +378,11 @@ def __call__( guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - latents (`jnp.array`, *optional*): + latents (`jnp.ndarray`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents array is generated by sampling using the supplied random `generator`. - controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0): + controlnet_conditioning_scale (`float` or `jnp.ndarray`, *optional*, defaults to 1.0): The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original `unet`. return_dict (`bool`, *optional*, defaults to `True`): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index a847cd15c6ce..bcf2a6217772 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -220,8 +220,8 @@ def _generate( height: int, width: int, guidance_scale: float, - latents: Optional[jnp.array] = None, - neg_prompt_ids: Optional[jnp.array] = None, + latents: Optional[jnp.ndarray] = None, + neg_prompt_ids: Optional[jnp.ndarray] = None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -316,9 +316,9 @@ def __call__( num_inference_steps: int = 50, height: Optional[int] = None, width: Optional[int] = None, - guidance_scale: Union[float, jnp.array] = 7.5, - latents: jnp.array = None, - neg_prompt_ids: jnp.array = None, + guidance_scale: Union[float, jnp.ndarray] = 7.5, + latents: jnp.ndarray = None, + neg_prompt_ids: jnp.ndarray = None, return_dict: bool = True, jit: bool = False, ): @@ -338,7 +338,7 @@ def __call__( guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - latents (`jnp.array`, *optional*): + latents (`jnp.ndarray`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents array is generated by sampling using the supplied random `generator`. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 42a79db6b2b2..c1fd310ea582 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -232,8 +232,8 @@ def get_timestep_start(self, num_inference_steps, strength): def _generate( self, - prompt_ids: jnp.array, - image: jnp.array, + prompt_ids: jnp.ndarray, + image: jnp.ndarray, params: Union[Dict, FrozenDict], prng_seed: jax.Array, start_timestep: int, @@ -241,8 +241,8 @@ def _generate( height: int, width: int, guidance_scale: float, - noise: Optional[jnp.array] = None, - neg_prompt_ids: Optional[jnp.array] = None, + noise: Optional[jnp.ndarray] = None, + neg_prompt_ids: Optional[jnp.ndarray] = None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -337,17 +337,17 @@ def loop_body(step, args): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt_ids: jnp.array, - image: jnp.array, + prompt_ids: jnp.ndarray, + image: jnp.ndarray, params: Union[Dict, FrozenDict], prng_seed: jax.Array, strength: float = 0.8, num_inference_steps: int = 50, height: Optional[int] = None, width: Optional[int] = None, - guidance_scale: Union[float, jnp.array] = 7.5, - noise: jnp.array = None, - neg_prompt_ids: jnp.array = None, + guidance_scale: Union[float, jnp.ndarray] = 7.5, + noise: jnp.ndarray = None, + neg_prompt_ids: jnp.ndarray = None, return_dict: bool = True, jit: bool = False, ): @@ -355,9 +355,9 @@ def __call__( The call function to the pipeline for generation. Args: - prompt_ids (`jnp.array`): + prompt_ids (`jnp.ndarray`): The prompt or prompts to guide image generation. - image (`jnp.array`): + image (`jnp.ndarray`): Array representing an image batch to be used as the starting point. params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights. @@ -379,7 +379,7 @@ def __call__( guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - noise (`jnp.array`, *optional*): + noise (`jnp.ndarray`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. The array is generated by sampling using the supplied random `generator`. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py index 153267da1067..b9a2331a061c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py @@ -266,17 +266,17 @@ def _run_safety_checker(self, images, safety_model_params, jit=False): def _generate( self, - prompt_ids: jnp.array, - mask: jnp.array, - masked_image: jnp.array, + prompt_ids: jnp.ndarray, + mask: jnp.ndarray, + masked_image: jnp.ndarray, params: Union[Dict, FrozenDict], prng_seed: jax.Array, num_inference_steps: int, height: int, width: int, guidance_scale: float, - latents: Optional[jnp.array] = None, - neg_prompt_ids: Optional[jnp.array] = None, + latents: Optional[jnp.ndarray] = None, + neg_prompt_ids: Optional[jnp.ndarray] = None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -394,17 +394,17 @@ def loop_body(step, args): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt_ids: jnp.array, - mask: jnp.array, - masked_image: jnp.array, + prompt_ids: jnp.ndarray, + mask: jnp.ndarray, + masked_image: jnp.ndarray, params: Union[Dict, FrozenDict], prng_seed: jax.Array, num_inference_steps: int = 50, height: Optional[int] = None, width: Optional[int] = None, - guidance_scale: Union[float, jnp.array] = 7.5, - latents: jnp.array = None, - neg_prompt_ids: jnp.array = None, + guidance_scale: Union[float, jnp.ndarray] = 7.5, + latents: jnp.ndarray = None, + neg_prompt_ids: jnp.ndarray = None, return_dict: bool = True, jit: bool = False, ): @@ -424,7 +424,7 @@ def __call__( guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - latents (`jnp.array`, *optional*): + latents (`jnp.ndarray`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents array is generated by sampling using the supplied random `generator`.