@@ -55,7 +55,7 @@ def range_arr(x):
5555
5656
5757@numba_funcify .register (Scan )
58- def numba_funcify_Scan (op , node , ** kwargs ):
58+ def numba_funcify_Scan (op : Scan , node , ** kwargs ):
5959 # Apply inner rewrites
6060 # TODO: Not sure this is the right place to do this, should we have a rewrite that
6161 # explicitly triggers the optimization of the inner graphs of Scan?
@@ -67,9 +67,32 @@ def numba_funcify_Scan(op, node, **kwargs):
6767 .optimizer
6868 )
6969 fgraph = op .fgraph
70+ # When the buffer can only hold one SITSOT or as as many MITSOT as there are taps,
71+ # We must always discard the oldest tap, so it's safe to destroy it in the inner function.
72+ # TODO: Allow inplace for MITMOT
73+ destroyable_sitsot = [
74+ inner_sitsot
75+ for outer_sitsot , inner_sitsot in zip (
76+ op .outer_sitsot (node .inputs ), op .inner_sitsot (fgraph .inputs ), strict = True
77+ )
78+ if outer_sitsot .type .shape [0 ] == 1
79+ ]
80+ destroyable_mitsot = [
81+ oldest_inner_mitmot
82+ for outer_mitsot , oldest_inner_mitmot , taps in zip (
83+ op .outer_mitsot (node .inputs ),
84+ op .oldest_inner_mitsot (fgraph .inputs ),
85+ op .info .mit_sot_in_slices ,
86+ strict = True ,
87+ )
88+ if outer_mitsot .type .shape [0 ] == abs (min (taps ))
89+ ]
90+ destroyable = {* destroyable_sitsot , * destroyable_mitsot }
7091 add_supervisor_to_fgraph (
7192 fgraph = fgraph ,
72- input_specs = [In (x , borrow = True , mutable = False ) for x in fgraph .inputs ],
93+ input_specs = [
94+ In (x , borrow = True , mutable = x in destroyable ) for x in fgraph .inputs
95+ ],
7396 accept_inplace = True ,
7497 )
7598 rewriter (fgraph )
0 commit comments