7070 get_slice_elements ,
7171 set_subtensor ,
7272)
73- from pytensor .tensor .variable import TensorConstant
73+ from pytensor .tensor .variable import TensorConstant , TensorVariable
7474
7575
7676list_opt_slice = [
@@ -1182,8 +1182,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
11821182 return subtensor_merge_replacements
11831183
11841184
1185- @node_rewriter ([Scan ])
1186- def scan_save_mem (fgraph , node ):
1185+ def scan_save_mem_rewrite (fgraph , node , backend_supports_output_pre_allocation : bool ):
11871186 r"""Graph optimizer that reduces scan memory consumption.
11881187
11891188 This optimizations attempts to determine if a `Scan` node, during its execution,
@@ -1214,10 +1213,16 @@ def scan_save_mem(fgraph, node):
12141213
12151214 The scan perform implementation takes the output sizes into consideration,
12161215 saving the newest results over the oldest ones whenever the buffer is filled.
1217- """
1218- if not isinstance (node .op , Scan ):
1219- return False
12201216
1217+ Paramaters
1218+ ----------
1219+ backend_supports_output_pre_allocation: bool
1220+ When the backend supports output pre-allocation Scan must keep buffers
1221+ with a length of required_states + 1, because the inner function will
1222+ attempt to write the inner function outputs directly into the provided
1223+ position in the outer circular buffer. This would invalidate results,
1224+ if the input is still needed for some other output computation.
1225+ """
12211226 if hasattr (fgraph , "shape_feature" ):
12221227 shape_of = fgraph .shape_feature .shape_of
12231228 else :
@@ -1270,14 +1275,15 @@ def scan_save_mem(fgraph, node):
12701275 # Note: For simplicity while Scans also have global_nsteps set to None.
12711276 # All step optimizations require knowing the shape of the output, which
12721277 # cannot be determined from the inputs alone.
1278+ global_nsteps : None | dict
12731279 assert len (node .outputs ) >= c_outs
12741280 if len (node .outputs ) == c_outs and not op .info .as_while :
12751281 global_nsteps = {"real" : - 1 , "sym" : []}
12761282 else :
12771283 global_nsteps = None
12781284
12791285 # Keeps track of the original slices that each client represent
1280- slices = [None for o in node .outputs ]
1286+ slices : list [ None | list ] = [None for o in node .outputs ]
12811287
12821288 # A list for each output indicating how many intermediate values
12831289 # should be stored. If negative it means none of the intermediate
@@ -1294,7 +1300,7 @@ def scan_save_mem(fgraph, node):
12941300 # or not
12951301 flag_store = False
12961302
1297- # 2.2 Loop over the clients
1303+ # 2.2 Loop over the clients to figure out how many steps we actually need to do in the Scan
12981304 for i , out in enumerate (node .outputs [:c_outs ]):
12991305 # look at all its clients
13001306 slices [i ] = []
@@ -1337,7 +1343,7 @@ def scan_save_mem(fgraph, node):
13371343 except KeyError :
13381344 length = out .shape [0 ]
13391345 cf_slice = get_canonical_form_slice (this_slice [0 ], length )
1340- slices [i ] += [(cf_slice , this_slice )]
1346+ slices [i ] += [(cf_slice , this_slice )] # type: ignore
13411347
13421348 if isinstance (this_slice [0 ], slice ) and this_slice [0 ].stop is None :
13431349 global_nsteps = None
@@ -1477,7 +1483,10 @@ def scan_save_mem(fgraph, node):
14771483 # for mitsots and sitsots (because mitmots are not
14781484 # currently supported by the mechanism) and only if
14791485 # the pre-allocation mechanism is activated.
1480- prealloc_outs = config .scan__allow_output_prealloc
1486+ prealloc_outs = (
1487+ backend_supports_output_pre_allocation
1488+ and config .scan__allow_output_prealloc
1489+ )
14811490
14821491 first_mitsot_idx = op_info .n_mit_mot
14831492 last_sitsot_idx = (
@@ -1486,6 +1495,8 @@ def scan_save_mem(fgraph, node):
14861495 preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx
14871496
14881497 if prealloc_outs and preallocable_output :
1498+ # TODO: If there's only one output or other outputs do not depend
1499+ # on the same input, we could reduce the buffer size to the minimum
14891500 pval = select_max (nw_steps - start + init_l [i ], init_l [i ] + 1 )
14901501 else :
14911502 pval = select_max (nw_steps - start + init_l [i ], init_l [i ])
@@ -1652,7 +1663,7 @@ def scan_save_mem(fgraph, node):
16521663 name = op .name ,
16531664 allow_gc = op .allow_gc ,
16541665 )
1655- new_outs = new_op (* node_ins , return_list = True )
1666+ new_outs = cast ( list [ TensorVariable ], new_op (* node_ins , return_list = True ) )
16561667
16571668 old_new = []
16581669 # 3.7 Get replace pairs for those outputs that do not change
@@ -1682,7 +1693,7 @@ def scan_save_mem(fgraph, node):
16821693 sl_ins = get_slice_elements (
16831694 nw_slice , lambda entry : isinstance (entry , Variable )
16841695 )
1685- new_o = subtens (new_outs [nw_pos ], * sl_ins )
1696+ new_o = cast ( TensorVariable , subtens (new_outs [nw_pos ], * sl_ins ) )
16861697 if new_o .ndim > 0 :
16871698 new_o = new_o [:: cnf_slice [1 ]]
16881699 replaced_outs .append (idx )
@@ -1737,7 +1748,7 @@ def scan_save_mem(fgraph, node):
17371748 sl_ins = get_slice_elements (
17381749 nw_slice , lambda entry : isinstance (entry , Variable )
17391750 )
1740- new_o = subtens (new_outs [nw_pos ], * sl_ins )
1751+ new_o = cast ( TensorVariable , subtens (new_outs [nw_pos ], * sl_ins ) )
17411752 if new_o .ndim > 0 :
17421753 new_o = new_o [:: cnf_slice [1 ]]
17431754 old_new += [(old , new_o )]
@@ -1768,6 +1779,20 @@ def scan_save_mem(fgraph, node):
17681779 return False
17691780
17701781
1782+ @node_rewriter ([Scan ])
1783+ def scan_save_mem_prealloc (fgraph , node ):
1784+ return scan_save_mem_rewrite (
1785+ fgraph , node , backend_supports_output_pre_allocation = True
1786+ )
1787+
1788+
1789+ @node_rewriter ([Scan ])
1790+ def scan_save_mem_no_prealloc (fgraph , node ):
1791+ return scan_save_mem_rewrite (
1792+ fgraph , node , backend_supports_output_pre_allocation = False
1793+ )
1794+
1795+
17711796class ScanMerge (GraphRewriter ):
17721797 r"""Graph optimizer that merges different scan ops.
17731798
@@ -2495,10 +2520,20 @@ def scan_push_out_dot1(fgraph, node):
24952520optdb .register ("scan_eqopt2" , scan_eqopt2 , "fast_run" , "scan" , position = 1.6 )
24962521# ScanSaveMem should execute only once per node.
24972522optdb .register (
2498- "scan_save_mem " ,
2499- in2out (scan_save_mem , ignore_newtrees = True ),
2523+ "scan_save_mem_prealloc " ,
2524+ in2out (scan_save_mem_prealloc , ignore_newtrees = True ),
25002525 "fast_run" ,
25012526 "scan" ,
2527+ "scan_save_mem" ,
2528+ position = 1.61 ,
2529+ )
2530+ optdb .register (
2531+ "scan_save_mem_no_prealloc" ,
2532+ in2out (scan_save_mem_no_prealloc , ignore_newtrees = True ),
2533+ "numba" ,
2534+ "jax" ,
2535+ "pytorch" ,
2536+ use_db_name_as_tag = False ,
25022537 position = 1.61 ,
25032538)
25042539optdb .register (
0 commit comments