Skip to content

Commit 73fddaf

Browse files
hvaarahawkinsp
andauthored
[JAX] Replace uses of jnp.array in types with jnp.ndarray. (huggingface#4719)
`jnp.array` is a function, not a type: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Change uses of `jnp.array` to `jnp.ndarray`. Co-authored-by: Peter Hawkins <[email protected]>
1 parent ab08277 commit 73fddaf

File tree

4 files changed

+45
-45
lines changed

4 files changed

+45
-45
lines changed

pipelines/controlnet/pipeline_flax_controlnet.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -238,14 +238,14 @@ def _run_safety_checker(self, images, safety_model_params, jit=False):
238238

239239
def _generate(
240240
self,
241-
prompt_ids: jnp.array,
242-
image: jnp.array,
241+
prompt_ids: jnp.ndarray,
242+
image: jnp.ndarray,
243243
params: Union[Dict, FrozenDict],
244244
prng_seed: jax.Array,
245245
num_inference_steps: int,
246246
guidance_scale: float,
247-
latents: Optional[jnp.array] = None,
248-
neg_prompt_ids: Optional[jnp.array] = None,
247+
latents: Optional[jnp.ndarray] = None,
248+
neg_prompt_ids: Optional[jnp.ndarray] = None,
249249
controlnet_conditioning_scale: float = 1.0,
250250
):
251251
height, width = image.shape[-2:]
@@ -348,41 +348,41 @@ def loop_body(step, args):
348348
@replace_example_docstring(EXAMPLE_DOC_STRING)
349349
def __call__(
350350
self,
351-
prompt_ids: jnp.array,
352-
image: jnp.array,
351+
prompt_ids: jnp.ndarray,
352+
image: jnp.ndarray,
353353
params: Union[Dict, FrozenDict],
354354
prng_seed: jax.Array,
355355
num_inference_steps: int = 50,
356-
guidance_scale: Union[float, jnp.array] = 7.5,
357-
latents: jnp.array = None,
358-
neg_prompt_ids: jnp.array = None,
359-
controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
356+
guidance_scale: Union[float, jnp.ndarray] = 7.5,
357+
latents: jnp.ndarray = None,
358+
neg_prompt_ids: jnp.ndarray = None,
359+
controlnet_conditioning_scale: Union[float, jnp.ndarray] = 1.0,
360360
return_dict: bool = True,
361361
jit: bool = False,
362362
):
363363
r"""
364364
The call function to the pipeline for generation.
365365
366366
Args:
367-
prompt_ids (`jnp.array`):
367+
prompt_ids (`jnp.ndarray`):
368368
The prompt or prompts to guide the image generation.
369-
image (`jnp.array`):
369+
image (`jnp.ndarray`):
370370
Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
371371
params (`Dict` or `FrozenDict`):
372372
Dictionary containing the model parameters/weights.
373-
prng_seed (`jax.Array` or `jax.Array`):
373+
prng_seed (`jax.Array`):
374374
Array containing random number generator key.
375375
num_inference_steps (`int`, *optional*, defaults to 50):
376376
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
377377
expense of slower inference.
378378
guidance_scale (`float`, *optional*, defaults to 7.5):
379379
A higher guidance scale value encourages the model to generate images closely linked to the text
380380
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
381-
latents (`jnp.array`, *optional*):
381+
latents (`jnp.ndarray`, *optional*):
382382
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
383383
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
384384
array is generated by sampling using the supplied random `generator`.
385-
controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0):
385+
controlnet_conditioning_scale (`float` or `jnp.ndarray`, *optional*, defaults to 1.0):
386386
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
387387
to the residual in the original `unet`.
388388
return_dict (`bool`, *optional*, defaults to `True`):

pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,8 @@ def _generate(
220220
height: int,
221221
width: int,
222222
guidance_scale: float,
223-
latents: Optional[jnp.array] = None,
224-
neg_prompt_ids: Optional[jnp.array] = None,
223+
latents: Optional[jnp.ndarray] = None,
224+
neg_prompt_ids: Optional[jnp.ndarray] = None,
225225
):
226226
if height % 8 != 0 or width % 8 != 0:
227227
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -316,9 +316,9 @@ def __call__(
316316
num_inference_steps: int = 50,
317317
height: Optional[int] = None,
318318
width: Optional[int] = None,
319-
guidance_scale: Union[float, jnp.array] = 7.5,
320-
latents: jnp.array = None,
321-
neg_prompt_ids: jnp.array = None,
319+
guidance_scale: Union[float, jnp.ndarray] = 7.5,
320+
latents: jnp.ndarray = None,
321+
neg_prompt_ids: jnp.ndarray = None,
322322
return_dict: bool = True,
323323
jit: bool = False,
324324
):
@@ -338,7 +338,7 @@ def __call__(
338338
guidance_scale (`float`, *optional*, defaults to 7.5):
339339
A higher guidance scale value encourages the model to generate images closely linked to the text
340340
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
341-
latents (`jnp.array`, *optional*):
341+
latents (`jnp.ndarray`, *optional*):
342342
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
343343
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
344344
array is generated by sampling using the supplied random `generator`.

pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -232,17 +232,17 @@ def get_timestep_start(self, num_inference_steps, strength):
232232

233233
def _generate(
234234
self,
235-
prompt_ids: jnp.array,
236-
image: jnp.array,
235+
prompt_ids: jnp.ndarray,
236+
image: jnp.ndarray,
237237
params: Union[Dict, FrozenDict],
238238
prng_seed: jax.Array,
239239
start_timestep: int,
240240
num_inference_steps: int,
241241
height: int,
242242
width: int,
243243
guidance_scale: float,
244-
noise: Optional[jnp.array] = None,
245-
neg_prompt_ids: Optional[jnp.array] = None,
244+
noise: Optional[jnp.ndarray] = None,
245+
neg_prompt_ids: Optional[jnp.ndarray] = None,
246246
):
247247
if height % 8 != 0 or width % 8 != 0:
248248
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -337,27 +337,27 @@ def loop_body(step, args):
337337
@replace_example_docstring(EXAMPLE_DOC_STRING)
338338
def __call__(
339339
self,
340-
prompt_ids: jnp.array,
341-
image: jnp.array,
340+
prompt_ids: jnp.ndarray,
341+
image: jnp.ndarray,
342342
params: Union[Dict, FrozenDict],
343343
prng_seed: jax.Array,
344344
strength: float = 0.8,
345345
num_inference_steps: int = 50,
346346
height: Optional[int] = None,
347347
width: Optional[int] = None,
348-
guidance_scale: Union[float, jnp.array] = 7.5,
349-
noise: jnp.array = None,
350-
neg_prompt_ids: jnp.array = None,
348+
guidance_scale: Union[float, jnp.ndarray] = 7.5,
349+
noise: jnp.ndarray = None,
350+
neg_prompt_ids: jnp.ndarray = None,
351351
return_dict: bool = True,
352352
jit: bool = False,
353353
):
354354
r"""
355355
The call function to the pipeline for generation.
356356
357357
Args:
358-
prompt_ids (`jnp.array`):
358+
prompt_ids (`jnp.ndarray`):
359359
The prompt or prompts to guide image generation.
360-
image (`jnp.array`):
360+
image (`jnp.ndarray`):
361361
Array representing an image batch to be used as the starting point.
362362
params (`Dict` or `FrozenDict`):
363363
Dictionary containing the model parameters/weights.
@@ -379,7 +379,7 @@ def __call__(
379379
guidance_scale (`float`, *optional*, defaults to 7.5):
380380
A higher guidance scale value encourages the model to generate images closely linked to the text
381381
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
382-
noise (`jnp.array`, *optional*):
382+
noise (`jnp.ndarray`, *optional*):
383383
Pre-generated noisy latents sampled from a Gaussian distribution to be used as inputs for image
384384
generation. Can be used to tweak the same generation with different prompts. The array is generated by
385385
sampling using the supplied random `generator`.

pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -266,17 +266,17 @@ def _run_safety_checker(self, images, safety_model_params, jit=False):
266266

267267
def _generate(
268268
self,
269-
prompt_ids: jnp.array,
270-
mask: jnp.array,
271-
masked_image: jnp.array,
269+
prompt_ids: jnp.ndarray,
270+
mask: jnp.ndarray,
271+
masked_image: jnp.ndarray,
272272
params: Union[Dict, FrozenDict],
273273
prng_seed: jax.Array,
274274
num_inference_steps: int,
275275
height: int,
276276
width: int,
277277
guidance_scale: float,
278-
latents: Optional[jnp.array] = None,
279-
neg_prompt_ids: Optional[jnp.array] = None,
278+
latents: Optional[jnp.ndarray] = None,
279+
neg_prompt_ids: Optional[jnp.ndarray] = None,
280280
):
281281
if height % 8 != 0 or width % 8 != 0:
282282
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -394,17 +394,17 @@ def loop_body(step, args):
394394
@replace_example_docstring(EXAMPLE_DOC_STRING)
395395
def __call__(
396396
self,
397-
prompt_ids: jnp.array,
398-
mask: jnp.array,
399-
masked_image: jnp.array,
397+
prompt_ids: jnp.ndarray,
398+
mask: jnp.ndarray,
399+
masked_image: jnp.ndarray,
400400
params: Union[Dict, FrozenDict],
401401
prng_seed: jax.Array,
402402
num_inference_steps: int = 50,
403403
height: Optional[int] = None,
404404
width: Optional[int] = None,
405-
guidance_scale: Union[float, jnp.array] = 7.5,
406-
latents: jnp.array = None,
407-
neg_prompt_ids: jnp.array = None,
405+
guidance_scale: Union[float, jnp.ndarray] = 7.5,
406+
latents: jnp.ndarray = None,
407+
neg_prompt_ids: jnp.ndarray = None,
408408
return_dict: bool = True,
409409
jit: bool = False,
410410
):
@@ -424,7 +424,7 @@ def __call__(
424424
guidance_scale (`float`, *optional*, defaults to 7.5):
425425
A higher guidance scale value encourages the model to generate images closely linked to the text
426426
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
427-
latents (`jnp.array`, *optional*):
427+
latents (`jnp.ndarray`, *optional*):
428428
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
429429
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
430430
array is generated by sampling using the supplied random `generator`.

0 commit comments

Comments
 (0)