2727from pytensor .graph .replace import clone_replace
2828from pytensor .graph .rewriting .basic import (
2929 GraphRewriter ,
30+ bfs_rewriter ,
3031 copy_stack_trace ,
31- in2out ,
3232 node_rewriter ,
3333)
3434from pytensor .graph .rewriting .db import EquilibriumDB , SequenceDB
@@ -2527,15 +2527,15 @@ def scan_push_out_dot1(fgraph, node):
25272527# ScanSaveMem should execute only once per node.
25282528optdb .register (
25292529 "scan_save_mem_prealloc" ,
2530- in2out (scan_save_mem_prealloc , ignore_newtrees = True ),
2530+ bfs_rewriter (scan_save_mem_prealloc , ignore_newtrees = True ),
25312531 "fast_run" ,
25322532 "scan" ,
25332533 "scan_save_mem" ,
25342534 position = 1.61 ,
25352535)
25362536optdb .register (
25372537 "scan_save_mem_no_prealloc" ,
2538- in2out (scan_save_mem_no_prealloc , ignore_newtrees = True ),
2538+ bfs_rewriter (scan_save_mem_no_prealloc , ignore_newtrees = True ),
25392539 "numba" ,
25402540 "jax" ,
25412541 "pytorch" ,
@@ -2556,7 +2556,7 @@ def scan_push_out_dot1(fgraph, node):
25562556
25572557scan_seqopt1 .register (
25582558 "scan_remove_constants_and_unused_inputs0" ,
2559- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2559+ bfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
25602560 "remove_constants_and_unused_inputs_scan" ,
25612561 "fast_run" ,
25622562 "scan" ,
@@ -2565,7 +2565,7 @@ def scan_push_out_dot1(fgraph, node):
25652565
25662566scan_seqopt1 .register (
25672567 "scan_push_out_non_seq" ,
2568- in2out (scan_push_out_non_seq , ignore_newtrees = True ),
2568+ bfs_rewriter (scan_push_out_non_seq , ignore_newtrees = True ),
25692569 "scan_pushout_nonseqs_ops" , # For backcompat: so it can be tagged with old name
25702570 "fast_run" ,
25712571 "scan" ,
@@ -2575,7 +2575,7 @@ def scan_push_out_dot1(fgraph, node):
25752575
25762576scan_seqopt1 .register (
25772577 "scan_push_out_seq" ,
2578- in2out (scan_push_out_seq , ignore_newtrees = True ),
2578+ bfs_rewriter (scan_push_out_seq , ignore_newtrees = True ),
25792579 "scan_pushout_seqs_ops" , # For backcompat: so it can be tagged with old name
25802580 "fast_run" ,
25812581 "scan" ,
@@ -2586,7 +2586,7 @@ def scan_push_out_dot1(fgraph, node):
25862586
25872587scan_seqopt1 .register (
25882588 "scan_push_out_dot1" ,
2589- in2out (scan_push_out_dot1 , ignore_newtrees = True ),
2589+ bfs_rewriter (scan_push_out_dot1 , ignore_newtrees = True ),
25902590 "scan_pushout_dot1" , # For backcompat: so it can be tagged with old name
25912591 "fast_run" ,
25922592 "more_mem" ,
@@ -2599,7 +2599,7 @@ def scan_push_out_dot1(fgraph, node):
25992599scan_seqopt1 .register (
26002600 "scan_push_out_add" ,
26012601 # TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
2602- in2out (scan_push_out_add , ignore_newtrees = False ),
2602+ bfs_rewriter (scan_push_out_add , ignore_newtrees = False ),
26032603 "scan_pushout_add" , # For backcompat: so it can be tagged with old name
26042604 "fast_run" ,
26052605 "more_mem" ,
@@ -2610,22 +2610,22 @@ def scan_push_out_dot1(fgraph, node):
26102610
26112611scan_eqopt2 .register (
26122612 "while_scan_merge_subtensor_last_element" ,
2613- in2out (while_scan_merge_subtensor_last_element , ignore_newtrees = True ),
2613+ bfs_rewriter (while_scan_merge_subtensor_last_element , ignore_newtrees = True ),
26142614 "fast_run" ,
26152615 "scan" ,
26162616)
26172617
26182618scan_eqopt2 .register (
26192619 "constant_folding_for_scan2" ,
2620- in2out (constant_folding , ignore_newtrees = True ),
2620+ bfs_rewriter (constant_folding , ignore_newtrees = True ),
26212621 "fast_run" ,
26222622 "scan" ,
26232623)
26242624
26252625
26262626scan_eqopt2 .register (
26272627 "scan_remove_constants_and_unused_inputs1" ,
2628- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2628+ bfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
26292629 "remove_constants_and_unused_inputs_scan" ,
26302630 "fast_run" ,
26312631 "scan" ,
@@ -2640,23 +2640,23 @@ def scan_push_out_dot1(fgraph, node):
26402640# After Merge optimization
26412641scan_eqopt2 .register (
26422642 "scan_remove_constants_and_unused_inputs2" ,
2643- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2643+ bfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
26442644 "remove_constants_and_unused_inputs_scan" ,
26452645 "fast_run" ,
26462646 "scan" ,
26472647)
26482648
26492649scan_eqopt2 .register (
26502650 "scan_merge_inouts" ,
2651- in2out (scan_merge_inouts , ignore_newtrees = True ),
2651+ bfs_rewriter (scan_merge_inouts , ignore_newtrees = True ),
26522652 "fast_run" ,
26532653 "scan" ,
26542654)
26552655
26562656# After everything else
26572657scan_eqopt2 .register (
26582658 "scan_remove_constants_and_unused_inputs3" ,
2659- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2659+ bfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
26602660 "remove_constants_and_unused_inputs_scan" ,
26612661 "fast_run" ,
26622662 "scan" ,
0 commit comments