diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index 71c38d04e0..a3d6e6da9d 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -112,25 +112,46 @@ def numba_funcify_Scan(op, node, **kwargs): # Inner-inputs are ordered as follows: # sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + # shared-inputs + non-sequences. + temp_scalar_storage_alloc_stmts: List[str] = [] + inner_in_exprs_scalar: List[str] = [] inner_in_exprs: List[str] = [] def add_inner_in_expr( - outer_in_name: str, tap_offset: Optional[int], storage_size_var: Optional[str] + outer_in_name: str, + tap_offset: Optional[int], + storage_size_var: Optional[str], + vector_slice_opt: bool, ): """Construct an inner-input expression.""" storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name) - indexed_inner_in_str = ( - storage_name - if tap_offset is None - else idx_to_str( - storage_name, tap_offset, size=storage_size_var, allow_scalar=False + if vector_slice_opt: + indexed_inner_in_str_scalar = idx_to_str( + storage_name, tap_offset, size=storage_size_var, allow_scalar=True + ) + temp_storage = f"{storage_name}_temp_scalar_{tap_offset}" + storage_dtype = outer_in_var.type.numpy_dtype.name + temp_scalar_storage_alloc_stmts.append( + f"{temp_storage} = np.empty((), dtype=np.{storage_dtype})" + ) + inner_in_exprs_scalar.append( + f"{temp_storage}[()] = {indexed_inner_in_str_scalar}" + ) + indexed_inner_in_str = temp_storage + else: + indexed_inner_in_str = ( + storage_name + if tap_offset is None + else idx_to_str( + storage_name, tap_offset, size=storage_size_var, allow_scalar=False + ) ) - ) inner_in_exprs.append(indexed_inner_in_str) for outer_in_name in outer_in_seqs_names: # These outer-inputs are indexed without offsets or storage wrap-around - add_inner_in_expr(outer_in_name, 0, None) + outer_in_var = outer_in_names_to_vars[outer_in_name] + is_vector = outer_in_var.ndim == 1 + add_inner_in_expr(outer_in_name, 0, None, vector_slice_opt=is_vector) inner_in_names_to_input_taps: Dict[str, Tuple[int, ...]] = dict( zip( @@ -190,8 +211,8 @@ def add_output_storage_post_proc_stmt( output_storage_post_proc_stmts.append( dedent( f""" - {outer_in_name}_shift = (i + {tap_size}) % ({storage_size}) - if {outer_in_name}_shift > 0: + if (i + {tap_size}) > {storage_size}: + {outer_in_name}_shift = (i + {tap_size}) % ({storage_size}) {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift] {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:] {outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left)) @@ -232,7 +253,13 @@ def add_output_storage_post_proc_stmt( for in_tap in input_taps: tap_offset = in_tap + tap_storage_size assert tap_offset >= 0 - add_inner_in_expr(outer_in_name, tap_offset, storage_size_name) + is_vector = outer_in_var.ndim == 1 + add_inner_in_expr( + outer_in_name, + tap_offset, + storage_size_name, + vector_slice_opt=is_vector, + ) output_taps = inner_in_names_to_output_taps.get( outer_in_name, [tap_storage_size] @@ -253,7 +280,7 @@ def add_output_storage_post_proc_stmt( else: storage_size_stmt = "" - add_inner_in_expr(outer_in_name, None, None) + add_inner_in_expr(outer_in_name, None, None, vector_slice_opt=False) inner_out_to_outer_in_stmts.append(storage_name) output_idx = outer_output_names.index(storage_name) @@ -325,7 +352,7 @@ def add_output_storage_post_proc_stmt( ) for name in outer_in_non_seqs_names: - add_inner_in_expr(name, None, None) + add_inner_in_expr(name, None, None, vector_slice_opt=False) if op.info.as_while: # The inner function will return a boolean as the last value @@ -333,9 +360,11 @@ def add_output_storage_post_proc_stmt( assert len(inner_in_exprs) == len(op.fgraph.inputs) + inner_scalar_in_args_to_temp_storage = "\n".join(inner_in_exprs_scalar) inner_in_args = create_arg_string(inner_in_exprs) inner_outputs = create_tuple_string(inner_output_names) input_storage_block = "\n".join(storage_alloc_stmts) + input_temp_scalar_storage_block = "\n".join(temp_scalar_storage_alloc_stmts) output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts) inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts) @@ -348,9 +377,13 @@ def scan({", ".join(outer_in_names)}): {indent(input_storage_block, " " * 4)} +{indent(input_temp_scalar_storage_block, " " * 4)} + i = 0 cond = np.array(False) while i < n_steps and not cond.item(): +{indent(inner_scalar_in_args_to_temp_storage, " " * 8)} + {inner_outputs} = scan_inner_func({inner_in_args}) {indent(inner_out_post_processing_block, " " * 8)} {indent(inner_out_to_outer_out_stmts, " " * 8)} @@ -367,8 +400,6 @@ def scan({", ".join(outer_in_names)}): } global_env["np"] = np - scalar_op_fn = compile_function_src( - scan_op_src, "scan", {**globals(), **global_env} - ) + scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env}) - return numba_basic.numba_njit(scalar_op_fn) + return numba_basic.numba_njit(scan_op_fn) diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index be0f64e02c..2481fc9a12 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -1,6 +1,7 @@ import numpy as np import pytest +import pytensor import pytensor.tensor as at from pytensor import config, function, grad from pytensor.compile.mode import Mode, get_mode @@ -9,7 +10,7 @@ from pytensor.scan.basic import scan from pytensor.scan.op import Scan from pytensor.scan.utils import until -from pytensor.tensor import log, vector +from pytensor.tensor import log, scalar, vector from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.random.utils import RandomStream from tests import unittest_tools as utt @@ -442,3 +443,54 @@ def test_inner_graph_optimized(): assert isinstance(inner_scan_node.op, Elemwise) and isinstance( inner_scan_node.op.scalar_op, Log1p ) + + +def test_vector_taps_benchmark(benchmark): + """Test vector taps performance. + + Vector taps get indexed into numeric types, that must be wrapped back into + scalar arrays. The numba Scan implementation has an optimization to reuse + these scalar arrays instead of allocating them in every iteration. + """ + n_steps = 1000 + + seq1 = vector("seq1", dtype="float64", shape=(n_steps,)) + seq2 = vector("seq2", dtype="float64", shape=(n_steps,)) + mitsot_init = vector("mitsot_init", dtype="float64", shape=(2,)) + sitsot_init = scalar("sitsot_init", dtype="float64") + + def step(seq1, seq2, mitsot1, mitsot2, sitsot1): + mitsot3 = mitsot1 + seq2 + mitsot2 + seq1 + sitsot2 = sitsot1 + mitsot3 + return mitsot3, sitsot2 + + outs, _ = scan( + fn=step, + sequences=[seq1, seq2], + outputs_info=[ + dict(initial=mitsot_init, taps=[-2, -1]), + dict(initial=sitsot_init, taps=[-1]), + ], + ) + + rng = np.random.default_rng(474) + test = { + seq1: rng.normal(size=n_steps), + seq2: rng.normal(size=n_steps), + mitsot_init: rng.normal(size=(2,)), + sitsot_init: rng.normal(), + } + + numba_fn = pytensor.function(list(test.keys()), outs, mode=get_mode("NUMBA")) + scan_nodes = [ + node for node in numba_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan) + ] + assert len(scan_nodes) == 1 + numba_res = numba_fn(*test.values()) + + ref_fn = pytensor.function(list(test.keys()), outs, mode=get_mode("FAST_COMPILE")) + ref_res = ref_fn(*test.values()) + for numba_r, ref_r in zip(numba_res, ref_res): + np.testing.assert_array_almost_equal(numba_r, ref_r) + + benchmark(numba_fn, *test.values())