Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,14 @@ def _run_safety_checker(self, images, safety_model_params, jit=False):

def _generate(
self,
prompt_ids: jnp.array,
image: jnp.array,
prompt_ids: jnp.ndarray,
image: jnp.ndarray,
params: Union[Dict, FrozenDict],
prng_seed: jax.Array,
num_inference_steps: int,
guidance_scale: float,
latents: Optional[jnp.array] = None,
neg_prompt_ids: Optional[jnp.array] = None,
latents: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.ndarray] = None,
controlnet_conditioning_scale: float = 1.0,
):
height, width = image.shape[-2:]
Expand Down Expand Up @@ -348,41 +348,41 @@ def loop_body(step, args):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt_ids: jnp.array,
image: jnp.array,
prompt_ids: jnp.ndarray,
image: jnp.ndarray,
params: Union[Dict, FrozenDict],
prng_seed: jax.Array,
num_inference_steps: int = 50,
guidance_scale: Union[float, jnp.array] = 7.5,
latents: jnp.array = None,
neg_prompt_ids: jnp.array = None,
controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
guidance_scale: Union[float, jnp.ndarray] = 7.5,
latents: jnp.ndarray = None,
neg_prompt_ids: jnp.ndarray = None,
controlnet_conditioning_scale: Union[float, jnp.ndarray] = 1.0,
return_dict: bool = True,
jit: bool = False,
):
r"""
The call function to the pipeline for generation.

Args:
prompt_ids (`jnp.array`):
prompt_ids (`jnp.ndarray`):
The prompt or prompts to guide the image generation.
image (`jnp.array`):
image (`jnp.ndarray`):
Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
prng_seed (`jax.Array` or `jax.Array`):
prng_seed (`jax.Array`):
Array containing random number generator key.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
latents (`jnp.array`, *optional*):
latents (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
array is generated by sampling using the supplied random `generator`.
controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0):
controlnet_conditioning_scale (`float` or `jnp.ndarray`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`.
return_dict (`bool`, *optional*, defaults to `True`):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ def _generate(
height: int,
width: int,
guidance_scale: float,
latents: Optional[jnp.array] = None,
neg_prompt_ids: Optional[jnp.array] = None,
latents: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.ndarray] = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
Expand Down Expand Up @@ -316,9 +316,9 @@ def __call__(
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: Union[float, jnp.array] = 7.5,
latents: jnp.array = None,
neg_prompt_ids: jnp.array = None,
guidance_scale: Union[float, jnp.ndarray] = 7.5,
latents: jnp.ndarray = None,
neg_prompt_ids: jnp.ndarray = None,
return_dict: bool = True,
jit: bool = False,
):
Expand All @@ -338,7 +338,7 @@ def __call__(
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
latents (`jnp.array`, *optional*):
latents (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
array is generated by sampling using the supplied random `generator`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,17 +232,17 @@ def get_timestep_start(self, num_inference_steps, strength):

def _generate(
self,
prompt_ids: jnp.array,
image: jnp.array,
prompt_ids: jnp.ndarray,
image: jnp.ndarray,
params: Union[Dict, FrozenDict],
prng_seed: jax.Array,
start_timestep: int,
num_inference_steps: int,
height: int,
width: int,
guidance_scale: float,
noise: Optional[jnp.array] = None,
neg_prompt_ids: Optional[jnp.array] = None,
noise: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.ndarray] = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
Expand Down Expand Up @@ -337,27 +337,27 @@ def loop_body(step, args):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt_ids: jnp.array,
image: jnp.array,
prompt_ids: jnp.ndarray,
image: jnp.ndarray,
params: Union[Dict, FrozenDict],
prng_seed: jax.Array,
strength: float = 0.8,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: Union[float, jnp.array] = 7.5,
noise: jnp.array = None,
neg_prompt_ids: jnp.array = None,
guidance_scale: Union[float, jnp.ndarray] = 7.5,
noise: jnp.ndarray = None,
neg_prompt_ids: jnp.ndarray = None,
return_dict: bool = True,
jit: bool = False,
):
r"""
The call function to the pipeline for generation.

Args:
prompt_ids (`jnp.array`):
prompt_ids (`jnp.ndarray`):
The prompt or prompts to guide image generation.
image (`jnp.array`):
image (`jnp.ndarray`):
Array representing an image batch to be used as the starting point.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
Expand All @@ -379,7 +379,7 @@ def __call__(
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
noise (`jnp.array`, *optional*):
noise (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. The array is generated by
sampling using the supplied random `generator`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,17 +266,17 @@ def _run_safety_checker(self, images, safety_model_params, jit=False):

def _generate(
self,
prompt_ids: jnp.array,
mask: jnp.array,
masked_image: jnp.array,
prompt_ids: jnp.ndarray,
mask: jnp.ndarray,
masked_image: jnp.ndarray,
params: Union[Dict, FrozenDict],
prng_seed: jax.Array,
num_inference_steps: int,
height: int,
width: int,
guidance_scale: float,
latents: Optional[jnp.array] = None,
neg_prompt_ids: Optional[jnp.array] = None,
latents: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.ndarray] = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
Expand Down Expand Up @@ -394,17 +394,17 @@ def loop_body(step, args):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt_ids: jnp.array,
mask: jnp.array,
masked_image: jnp.array,
prompt_ids: jnp.ndarray,
mask: jnp.ndarray,
masked_image: jnp.ndarray,
params: Union[Dict, FrozenDict],
prng_seed: jax.Array,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: Union[float, jnp.array] = 7.5,
latents: jnp.array = None,
neg_prompt_ids: jnp.array = None,
guidance_scale: Union[float, jnp.ndarray] = 7.5,
latents: jnp.ndarray = None,
neg_prompt_ids: jnp.ndarray = None,
return_dict: bool = True,
jit: bool = False,
):
Expand All @@ -424,7 +424,7 @@ def __call__(
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
latents (`jnp.array`, *optional*):
latents (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
array is generated by sampling using the supplied random `generator`.
Expand Down