@@ -339,39 +339,6 @@ def power_step(prior_result, x):
339339 compare_numba_and_py ([A ], result , test_input_vals )
340340
341341
@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_basic(n_steps_val):
    """Compare Numba and Python results with the `scan_save_mem` rewrite enabled.

    A two-tap recurrence (taps -2 and -1) is scanned for a symbolic number of
    steps, and both backends are compiled with `scan_save_mem` included.
    """

    def f_pow2(x_tm2, x_tm1):
        return 2 * x_tm1 + x_tm2

    init_x = pt.dvector("init_x")
    n_steps = pt.iscalar("n_steps")

    # Single recurrent output reading its two previous values.
    taps_spec = {"initial": init_x, "taps": [-2, -1]}
    output, _ = scan(
        f_pow2,
        sequences=[],
        outputs_info=[taps_spec],
        non_sequences=[],
        n_steps=n_steps,
    )

    compare_numba_and_py(
        [init_x, n_steps],
        [output],
        (np.array([1.0, 2.0]), n_steps_val),
        numba_mode=get_mode("NUMBA").including("scan_save_mem"),
        py_mode=Mode("py").including("scan_save_mem"),
    )
373-
374-
375342def test_grad_sitsot ():
376343 def get_sum_of_grad (inp ):
377344 scan_outputs , updates = scan (
@@ -482,3 +449,120 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
482449 np .testing .assert_array_almost_equal (numba_r , ref_r )
483450
484451 benchmark (numba_fn , * test .values ())
452+
453+
@pytest.mark.parametrize(
    "buffer_size", ("unit", "aligned", "misaligned", "whole", "whole+init")
)
@pytest.mark.parametrize("n_steps, op_size", [(10, 2), (512, 2), (512, 256)])
class TestScanSITSOTBuffer:
    """Check the buffer size of a single-tap Scan after compilation with Numba.

    Each ``buffer_size`` case keeps a different part of the Scan output and
    asserts the leading dimension of the Scan node's second input in the
    compiled function graph.
    """

    def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
        """Compile the Scan, check its buffer length and optionally benchmark it.

        Parameters
        ----------
        n_steps
            The Scan runs for ``n_steps - 1`` steps.
        op_size
            Static length of the state vector ``x0``.
        buffer_size
            Name of the case: "unit", "aligned", "misaligned", "whole" or
            "whole+init".
        benchmark
            Optional callable; when given it is invoked as
            ``benchmark(numba_fn, x_test)`` after the checks pass.

        Raises
        ------
        ValueError
            If ``buffer_size`` is not one of the supported case names.
        """
        x0 = pt.vector(shape=(op_size,), dtype="float64")
        xs, _ = pytensor.scan(
            fn=lambda xtm1: (xtm1 + 1),
            outputs_info=[x0],
            n_steps=n_steps - 1,  # 1- makes it easier to align/misalign
        )
        # NOTE: the aligned/misaligned wording below was written for the
        # n_steps=10 parametrization (9 steps).
        if buffer_size == "unit":
            xs_kept = xs[-1]  # Only last state is used
            expected_buffer_size = 2
        elif buffer_size == "aligned":
            xs_kept = xs[-2:]  # Buffer is aligned at the end of 9 steps
            expected_buffer_size = 2
        elif buffer_size == "misaligned":
            xs_kept = xs[-3:]  # Buffer is misaligned at the end of 9 steps
            expected_buffer_size = 3
        elif buffer_size == "whole":
            xs_kept = xs  # What users think is the whole buffer
            expected_buffer_size = n_steps - 1
        elif buffer_size == "whole+init":
            xs_kept = xs.owner.inputs[0]  # Whole buffer actually used by Scan
            expected_buffer_size = n_steps
        else:
            # Fail here instead of with an UnboundLocalError on xs_kept below.
            raise ValueError(f"Unknown buffer_size case: {buffer_size!r}")

        x_test = np.zeros(x0.type.shape)
        numba_fn, _ = compare_numba_and_py(
            [x0],
            [xs_kept],
            test_inputs=[x_test],
            numba_mode="NUMBA",  # Default doesn't include optimizations
            eval_obj_mode=False,
        )
        # Unpacking asserts there is exactly one Scan node in the compiled graph.
        [scan_node] = [
            node
            for node in numba_fn.maker.fgraph.toposort()
            if isinstance(node.op, Scan)
        ]
        # Presumably inputs[0] is the number of steps and inputs[1] the state
        # buffer -- confirm against the Scan Op input layout.
        buffer = scan_node.inputs[1]
        assert buffer.type.shape[0] == expected_buffer_size

        if benchmark is not None:
            numba_fn.trust_input = True
            benchmark(numba_fn, x_test)

    def test_sit_sot_buffer(self, n_steps, op_size, buffer_size):
        """Run the buffer-size check without benchmarking."""
        self.buffer_tester(n_steps, op_size, buffer_size, benchmark=None)

    def test_sit_sot_buffer_benchmark(self, n_steps, op_size, buffer_size, benchmark):
        """Run the buffer-size check and benchmark the compiled function."""
        self.buffer_tester(n_steps, op_size, buffer_size, benchmark=benchmark)
507+
508+
@pytest.mark.parametrize("constant_n_steps", [False, True])
@pytest.mark.parametrize("n_steps_val", [1, 1000])
class TestScanMITSOTBuffer:
    def buffer_tester(self, constant_n_steps, n_steps_val, benchmark=None):
        """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite.

        A two-tap recurrence is scanned with either a constant or a symbolic
        number of steps, and only the last output is requested. After
        comparing the Numba and Python results, the outer buffer of the
        recurrent output is checked to have shape ``(3,)``. When ``benchmark``
        is given it is called with the compiled Numba function and the test
        values.
        """

        def f_pow2(x_tm2, x_tm1):
            return 2 * x_tm1 + x_tm2

        init_x = pt.vector("init_x", shape=(2,))
        n_steps = pt.iscalar("n_steps")
        init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype)

        # With a constant step count, `n_steps` is neither a graph input nor
        # a test value.
        if constant_n_steps:
            scan_n_steps = n_steps_val
            graph_inputs = [init_x]
            test_vals = [init_x_val]
        else:
            scan_n_steps = n_steps
            graph_inputs = [init_x, n_steps]
            test_vals = [
                init_x_val,
                np.asarray(n_steps_val, dtype=n_steps.type.dtype),
            ]

        output, _ = scan(
            f_pow2,
            sequences=[],
            outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
            non_sequences=[],
            n_steps=scan_n_steps,
        )

        numba_fn, _ = compare_numba_and_py(
            graph_inputs,
            [output[-1]],
            test_vals,
            numba_mode="NUMBA",
            eval_obj_mode=False,
        )

        if constant_n_steps and n_steps_val == 1:
            # There's no Scan in the graph when nsteps=constant(1)
            return

        # Check the buffer size has been optimized
        scan_nodes = [
            node
            for node in numba_fn.maker.fgraph.toposort()
            if isinstance(node.op, Scan)
        ]
        [scan_node] = scan_nodes  # exactly one Scan is expected
        [mitsot_buffer] = scan_node.op.outer_mitsot(scan_node.inputs)
        # `n_steps` is always supplied; it is ignored when it is unused.
        eval_inputs = {init_x: init_x_val, n_steps: n_steps_val}
        mitsot_buffer_shape = mitsot_buffer.shape.eval(
            eval_inputs,
            accept_inplace=True,
            on_unused_input="ignore",
        )
        assert tuple(mitsot_buffer_shape) == (3,)

        if benchmark is not None:
            numba_fn.trust_input = True
            benchmark(numba_fn, *test_vals)

    def test_mit_sot_buffer(self, constant_n_steps, n_steps_val):
        """Run the buffer check without benchmarking."""
        self.buffer_tester(constant_n_steps, n_steps_val, benchmark=None)

    def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark):
        """Run the buffer check and benchmark the compiled function."""
        self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark)
0 commit comments