1+ import  functools 
12from  typing  import  List , Tuple 
23
34import  numpy  as  np 
45
5- from  pytensor  import  Variable , as_symbolic 
6+ from  pytensor  import  Variable , as_symbolic ,  clone_replace 
67from  pytensor .graph  import  FunctionGraph 
8+ from  pytensor .graph .basic  import  Constant , truncated_graph_inputs 
79from  pytensor .loop .op  import  Scan 
810from  pytensor .scan .utils  import  until 
9- from  pytensor .tensor  import  as_tensor , empty_like 
11+ from  pytensor .tensor  import  as_tensor , constant ,  empty_like ,  minimum 
1012
1113
1214def  scan (
@@ -20,6 +22,8 @@ def scan(
2022    if  sequences  is  None  and  n_steps  is  None :
2123        raise  ValueError ("Must provide n_steps when scanning without sequences" )
2224
25+     # TODO: init_states should be made opaque to the inner function, 
26+     #  since any relationship to the outer graph no longer holds 
2327    if  init_states  is  None :
2428        init_states  =  []
2529    else :
@@ -34,20 +38,31 @@ def scan(
3438            sequences  =  [sequences ]
3539        sequences  =  [as_tensor (s ) for  s  in  sequences ]
3640
41+     if  sequences :
42+         leading_dims  =  [seq .shape [0 ] for  seq  in  sequences ]
43+         shortest_dim  =  functools .reduce (minimum , leading_dims )
44+         if  n_steps  is  None :
45+             n_steps  =  shortest_dim 
46+         else :
47+             n_steps  =  minimum (n_steps , shortest_dim )
48+ 
3749    if  non_sequences  is  None :
3850        non_sequences  =  []
3951    else :
4052        if  not  isinstance (non_sequences , (tuple , list )):
4153            non_sequences  =  [non_sequences ]
4254        non_sequences  =  [as_symbolic (n ) for  n  in  non_sequences ]
4355
56+     # Create subsequence inputs for the inner function 
57+     idx  =  constant (0 , dtype = "int64" , name = "idx" )
58+     symbolic_idx  =  idx .type (name = "idx" )
59+     subsequences  =  [s [symbolic_idx ] for  s  in  sequences ]
4460    # Note: Old scan order is sequences + init + non_sequences 
45-     inner_sequences  =  [s [0 ] for  s  in  sequences ]
46-     inner_inputs  =  [i .type () for  i  in  init_states  +  inner_sequences  +  non_sequences ]
47-     inner_outputs  =  fn (* inner_inputs )
48-     if  not  isinstance (inner_outputs , (tuple , list )):
49-         inner_outputs  =  [inner_outputs ]
50-     next_states  =  [out  for  out  in  inner_outputs  if  not  isinstance (out , until )]
61+     fn_inputs  =  init_states  +  subsequences  +  non_sequences 
62+     fn_outputs  =  fn (* fn_inputs )
63+     if  not  isinstance (fn_outputs , (tuple , list )):
64+         fn_outputs  =  [fn_outputs ]
65+     next_states  =  [out  for  out  in  fn_outputs  if  not  isinstance (out , until )]
5166
5267    if  len (next_states ) >  len (init_states ):
5368        if  not  init_states :
@@ -61,27 +76,45 @@ def scan(
6176    prev_states  =  []
6277    for  i , (init_state , next_state ) in  enumerate (zip (init_states , next_states )):
6378        if  init_state  is  None :
79+             # next_state may reference idx, let's replace that by the initial value 
80+             [next_state ] =  clone_replace (
81+                 output = [next_state ], replace = {symbolic_idx : idx }
82+             )
6483            init_state  =  empty_like (next_state )
65-             init_state .name  =  "empty_init_state" 
66-             inner_inputs .insert (i , init_state .type ())
84+             init_state .name  =  (
85+                 "empty_init_state"   # add 1 offset, since idx is the first state 
86+             )
6787        prev_states .append (init_state )
6888
69-     until_condition  =  [out .condition  for  out  in  inner_outputs  if  isinstance (out , until )]
89+     until_condition  =  [out .condition  for  out  in  fn_outputs  if  isinstance (out , until )]
7090    if  not  until_condition :
7191        until_condition  =  [as_tensor (np .array (True ))]
7292    if  len (until_condition ) >  1 :
7393        raise  ValueError ("Only one until condition can be returned" )
7494
75-     update_fg  =  FunctionGraph (
76-         inputs = inner_inputs , outputs = until_condition  +  next_states 
95+     fgraph_inputs  =  [symbolic_idx ] +  prev_states  +  sequences  +  non_sequences 
96+     fgraph_outputs  =  until_condition  +  [symbolic_idx  +  1 ] +  next_states 
97+ 
98+     all_fgraph_inputs  =  truncated_graph_inputs (
99+         fgraph_outputs , ancestors_to_include = fgraph_inputs 
100+     )
101+     extra_fgraph_inputs  =  [
102+         inp 
103+         for  inp  in  all_fgraph_inputs 
104+         if  (not  isinstance (inp , Constant ) and  inp  not  in fgraph_inputs )
105+     ]
106+     fgraph_inputs  =  fgraph_inputs  +  extra_fgraph_inputs 
107+     update_fg  =  FunctionGraph (inputs = fgraph_inputs , outputs = fgraph_outputs )
108+ 
109+     scan_op  =  Scan (update_fg = update_fg )
110+     scan_outs  =  scan_op (
111+         n_steps , idx , * prev_states , * sequences , * non_sequences , * extra_fgraph_inputs 
77112    )
78-     scan_op  =  Scan (update_fg = update_fg , n_sequences = len (sequences ))
79-     scan_outs  =  scan_op (n_steps , * prev_states , * sequences , * non_sequences )
80113    assert  isinstance (scan_outs , list )
81114    last_states  =  scan_outs [: scan_op .n_states ]
82115    traces  =  scan_outs [scan_op .n_states  :]
83- 
84-     return  last_states , traces 
116+      # Don't return the inner index state 
117+     return  last_states [ 1 :] , traces [ 1 :] 
85118
86119
87120def  map (
0 commit comments