1313from pytensor .scan .op import Scan
1414from pytensor .tensor import random
1515from pytensor .tensor .math import gammaln , log
16- from pytensor .tensor .type import lscalar , scalar , vector , dmatrix , dvector
16+ from pytensor .tensor .type import dmatrix , dvector , lscalar , scalar , vector
1717from tests .link .jax .test_basic import compare_jax_and_py
1818
1919
20- # jax = pytest.importorskip("jax")
21-
22- import jax
23- jax .config .update ('jax_platform_name' , 'cpu' )
20+ jax = pytest .importorskip ("jax" )
2421
2522
2623@pytest .mark .parametrize ("view" , [None , (- 1 ,), slice (- 2 , None , None )])
@@ -322,22 +319,24 @@ def input_step_fn(y_tm1, y_tm3, a):
322319 compare_jax_and_py (out_fg , test_input_vals )
323320
324321
325- @pytest .mark .parametrize (' x0_func' , [dvector , dmatrix ])
326- @pytest .mark .parametrize (' A_func' , [dmatrix , dmatrix ])
322+ @pytest .mark .parametrize (" x0_func" , [dvector , dmatrix ])
323+ @pytest .mark .parametrize (" A_func" , [dmatrix , dmatrix ])
327324def test_nd_scan_sit_sot (x0_func , A_func ):
328- x0 = x0_func ('x0' )
329- A = A_func ('A' )
325+ x0 = x0_func ("x0" )
326+ A = A_func ("A" )
330327
331328 n_steps = 3
332329 k = 3
333330
334331 # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
335- xs , _ = scan (lambda X , A : A @ X ,
336- non_sequences = [A ],
337- outputs_info = [x0 ],
338- n_steps = n_steps ,
339- mode = get_mode ('JAX' ))
340-
332+ xs , _ = scan (
333+ lambda X , A : A @ X ,
334+ non_sequences = [A ],
335+ outputs_info = [x0 ],
336+ n_steps = n_steps ,
337+ mode = get_mode ("JAX" ),
338+ )
339+
341340 x0_val = np .arange (k ) if x0 .ndim == 1 else np .diag (np .arange (k ))
342341 A_val = np .eye (k )
343342
@@ -347,19 +346,21 @@ def test_nd_scan_sit_sot(x0_func, A_func):
347346
348347
349348def test_nd_scan_sit_sot_with_seq ():
350- x = dmatrix ('x0' )
351- A = dmatrix ('A' )
349+ x = dmatrix ("x0" )
350+ A = dmatrix ("A" )
352351
353352 n_steps = 3
354353 k = 3
355354
356355 # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
357- xs , _ = scan (lambda X , A : A @ X ,
358- non_sequences = [A ],
359- sequences = [x ],
360- n_steps = n_steps ,
361- mode = get_mode ('JAX' ))
362-
356+ xs , _ = scan (
357+ lambda X , A : A @ X ,
358+ non_sequences = [A ],
359+ sequences = [x ],
360+ n_steps = n_steps ,
361+ mode = get_mode ("JAX" ),
362+ )
363+
363364 x_val = np .tile (np .arange (k ), n_steps ).reshape (n_steps , k )
364365 A_val = np .eye (k )
365366
@@ -369,17 +370,17 @@ def test_nd_scan_sit_sot_with_seq():
369370
370371
371372def test_nd_scan_mit_sot ():
372- x0 = dmatrix ('x0' )
373- A = dmatrix ('A' )
374- B = dmatrix ('B' )
373+ x0 = dmatrix ("x0" )
374+ A = dmatrix ("A" )
375+ B = dmatrix ("B" )
375376
376377 # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
377378 xs , _ = scan (
378379 lambda xtm3 , xtm1 , A , B : A @ xtm3 + B @ xtm1 ,
379380 outputs_info = [{"initial" : x0 , "taps" : [- 3 , - 1 ]}],
380381 non_sequences = [A , B ],
381382 n_steps = 10 ,
382- mode = get_mode (' JAX' )
383+ mode = get_mode (" JAX" ),
383384 )
384385
385386 fg = FunctionGraph ([x0 , A , B ], [xs ])
@@ -392,8 +393,8 @@ def test_nd_scan_mit_sot():
392393
393394
394395def test_nd_scan_sit_sot_with_carry ():
395- x0 = dvector ('x0' )
396- A = dmatrix ('A' )
396+ x0 = dvector ("x0" )
397+ A = dmatrix ("A" )
397398
398399 def step (x , A ):
399400 return A @ x , x .sum ()
@@ -404,7 +405,7 @@ def step(x, A):
404405 outputs_info = [x0 , None ],
405406 non_sequences = [A ],
406407 n_steps = 10 ,
407- mode = get_mode (' JAX' )
408+ mode = get_mode (" JAX" ),
408409 )
409410
410411 fg = FunctionGraph ([x0 , A ], xs )
0 commit comments