@@ -112,25 +112,46 @@ def numba_funcify_Scan(op, node, **kwargs):
112112 # Inner-inputs are ordered as follows:
113113 # sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
114114 # shared-inputs + non-sequences.
115+ temp_scalar_storage_alloc_stmts : List [str ] = []
116+ inner_in_exprs_scalar : List [str ] = []
115117 inner_in_exprs : List [str ] = []
116118
117119 def add_inner_in_expr (
118- outer_in_name : str , tap_offset : Optional [int ], storage_size_var : Optional [str ]
120+ outer_in_name : str ,
121+ tap_offset : Optional [int ],
122+ storage_size_var : Optional [str ],
123+ vector_slice_opt : bool ,
119124 ):
120125 """Construct an inner-input expression."""
121126 storage_name = outer_in_to_storage_name .get (outer_in_name , outer_in_name )
122- indexed_inner_in_str = (
123- storage_name
124- if tap_offset is None
125- else idx_to_str (
126- storage_name , tap_offset , size = storage_size_var , allow_scalar = False
127+ if vector_slice_opt :
128+ indexed_inner_in_str_scalar = idx_to_str (
129+ storage_name , tap_offset , size = storage_size_var , allow_scalar = True
130+ )
131+ temp_storage = f"{ storage_name } _temp_scalar_{ tap_offset } "
132+ storage_dtype = outer_in_var .type .numpy_dtype .name
133+ temp_scalar_storage_alloc_stmts .append (
134+ f"{ temp_storage } = np.empty((), dtype=np.{ storage_dtype } )"
135+ )
136+ inner_in_exprs_scalar .append (
137+ f"{ temp_storage } [()] = { indexed_inner_in_str_scalar } "
138+ )
139+ indexed_inner_in_str = temp_storage
140+ else :
141+ indexed_inner_in_str = (
142+ storage_name
143+ if tap_offset is None
144+ else idx_to_str (
145+ storage_name , tap_offset , size = storage_size_var , allow_scalar = False
146+ )
127147 )
128- )
129148 inner_in_exprs .append (indexed_inner_in_str )
130149
131150 for outer_in_name in outer_in_seqs_names :
132151 # These outer-inputs are indexed without offsets or storage wrap-around
133- add_inner_in_expr (outer_in_name , 0 , None )
152+ outer_in_var = outer_in_names_to_vars [outer_in_name ]
153+ is_vector = outer_in_var .ndim == 1
154+ add_inner_in_expr (outer_in_name , 0 , None , vector_slice_opt = is_vector )
134155
135156 inner_in_names_to_input_taps : Dict [str , Tuple [int , ...]] = dict (
136157 zip (
@@ -232,7 +253,13 @@ def add_output_storage_post_proc_stmt(
232253 for in_tap in input_taps :
233254 tap_offset = in_tap + tap_storage_size
234255 assert tap_offset >= 0
235- add_inner_in_expr (outer_in_name , tap_offset , storage_size_name )
256+ is_vector = outer_in_var .ndim == 1
257+ add_inner_in_expr (
258+ outer_in_name ,
259+ tap_offset ,
260+ storage_size_name ,
261+ vector_slice_opt = is_vector ,
262+ )
236263
237264 output_taps = inner_in_names_to_output_taps .get (
238265 outer_in_name , [tap_storage_size ]
@@ -253,7 +280,7 @@ def add_output_storage_post_proc_stmt(
253280
254281 else :
255282 storage_size_stmt = ""
256- add_inner_in_expr (outer_in_name , None , None )
283+ add_inner_in_expr (outer_in_name , None , None , vector_slice_opt = False )
257284 inner_out_to_outer_in_stmts .append (storage_name )
258285
259286 output_idx = outer_output_names .index (storage_name )
@@ -325,17 +352,19 @@ def add_output_storage_post_proc_stmt(
325352 )
326353
327354 for name in outer_in_non_seqs_names :
328- add_inner_in_expr (name , None , None )
355+ add_inner_in_expr (name , None , None , vector_slice_opt = False )
329356
330357 if op .info .as_while :
331358 # The inner function will return a boolean as the last value
332359 inner_out_to_outer_in_stmts .append ("cond" )
333360
334361 assert len (inner_in_exprs ) == len (op .fgraph .inputs )
335362
363+ inner_scalar_in_args_to_temp_storage = "\n " .join (inner_in_exprs_scalar )
336364 inner_in_args = create_arg_string (inner_in_exprs )
337365 inner_outputs = create_tuple_string (inner_output_names )
338366 input_storage_block = "\n " .join (storage_alloc_stmts )
367+ input_temp_scalar_storage_block = "\n " .join (temp_scalar_storage_alloc_stmts )
339368 output_storage_post_processing_block = "\n " .join (output_storage_post_proc_stmts )
340369 inner_out_post_processing_block = "\n " .join (inner_out_post_processing_stmts )
341370
@@ -348,9 +377,13 @@ def scan({", ".join(outer_in_names)}):
348377
349378{ indent (input_storage_block , " " * 4 )}
350379
380+ { indent (input_temp_scalar_storage_block , " " * 4 )}
381+
351382 i = 0
352383 cond = np.array(False)
353384 while i < n_steps and not cond.item():
385+ { indent (inner_scalar_in_args_to_temp_storage , " " * 8 )}
386+
354387 { inner_outputs } = scan_inner_func({ inner_in_args } )
355388{ indent (inner_out_post_processing_block , " " * 8 )}
356389{ indent (inner_out_to_outer_out_stmts , " " * 8 )}
0 commit comments