From 864ea0a324fcda5ddff7f79e3beb2a3ef8e11350 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 6 Mar 2023 10:24:50 +0100
Subject: [PATCH 1/3] Numba scan: fix typo

---
 pytensor/link/numba/dispatch/scan.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py
index 71c38d04e0..6926759285 100644
--- a/pytensor/link/numba/dispatch/scan.py
+++ b/pytensor/link/numba/dispatch/scan.py
@@ -367,8 +367,6 @@ def scan({", ".join(outer_in_names)}):
     }
     global_env["np"] = np
 
-    scalar_op_fn = compile_function_src(
-        scan_op_src, "scan", {**globals(), **global_env}
-    )
+    scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env})
 
-    return numba_basic.numba_njit(scalar_op_fn)
+    return numba_basic.numba_njit(scan_op_fn)

From c1ebf7675d197249c36e2cc31f8936212a1c8ffa Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 6 Mar 2023 10:20:25 +0100
Subject: [PATCH 2/3] Numba scan: only rotate outputs if indexing wraps around
 storage size

---
 pytensor/link/numba/dispatch/scan.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py
index 6926759285..aa7d5ff331 100644
--- a/pytensor/link/numba/dispatch/scan.py
+++ b/pytensor/link/numba/dispatch/scan.py
@@ -190,8 +190,8 @@ def add_output_storage_post_proc_stmt(
         output_storage_post_proc_stmts.append(
             dedent(
                 f"""
-                {outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
-                if {outer_in_name}_shift > 0:
+                if (i + {tap_size}) > {storage_size}:
+                    {outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
                     {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
                     {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
                     {outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))

From 36b05f7f591c2a20c90cd10802c1ed4584377387 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 6 Mar 2023 11:28:52 +0100
Subject: [PATCH 3/3] Numba scan: reuse scalar arrays for taps from vector
 inputs

Indexing vector inputs to create taps during scan yields numeric
variables, which must be wrapped again into scalar arrays before being
passed into the internal function. This commit pre-allocates such arrays
and reuses them during looping.
---
 pytensor/link/numba/dispatch/scan.py | 55 ++++++++++++++++++++++------
 tests/link/numba/test_scan.py        | 54 ++++++++++++++++++++++++++-
 2 files changed, 97 insertions(+), 12 deletions(-)

diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py
index aa7d5ff331..a3d6e6da9d 100644
--- a/pytensor/link/numba/dispatch/scan.py
+++ b/pytensor/link/numba/dispatch/scan.py
@@ -112,25 +112,46 @@ def numba_funcify_Scan(op, node, **kwargs):
     # Inner-inputs are ordered as follows:
     # sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
     # shared-inputs + non-sequences.
+    temp_scalar_storage_alloc_stmts: List[str] = []
+    inner_in_exprs_scalar: List[str] = []
     inner_in_exprs: List[str] = []
 
     def add_inner_in_expr(
-        outer_in_name: str, tap_offset: Optional[int], storage_size_var: Optional[str]
+        outer_in_name: str,
+        tap_offset: Optional[int],
+        storage_size_var: Optional[str],
+        vector_slice_opt: bool,
     ):
         """Construct an inner-input expression."""
         storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name)
-        indexed_inner_in_str = (
-            storage_name
-            if tap_offset is None
-            else idx_to_str(
-                storage_name, tap_offset, size=storage_size_var, allow_scalar=False
+        if vector_slice_opt:
+            indexed_inner_in_str_scalar = idx_to_str(
+                storage_name, tap_offset, size=storage_size_var, allow_scalar=True
+            )
+            temp_storage = f"{storage_name}_temp_scalar_{tap_offset}"
+            storage_dtype = outer_in_var.type.numpy_dtype.name
+            temp_scalar_storage_alloc_stmts.append(
+                f"{temp_storage} = np.empty((), dtype=np.{storage_dtype})"
+            )
+            inner_in_exprs_scalar.append(
+                f"{temp_storage}[()] = {indexed_inner_in_str_scalar}"
+            )
+            indexed_inner_in_str = temp_storage
+        else:
+            indexed_inner_in_str = (
+                storage_name
+                if tap_offset is None
+                else idx_to_str(
+                    storage_name, tap_offset, size=storage_size_var, allow_scalar=False
+                )
             )
-        )
         inner_in_exprs.append(indexed_inner_in_str)
 
     for outer_in_name in outer_in_seqs_names:
         # These outer-inputs are indexed without offsets or storage wrap-around
-        add_inner_in_expr(outer_in_name, 0, None)
+        outer_in_var = outer_in_names_to_vars[outer_in_name]
+        is_vector = outer_in_var.ndim == 1
+        add_inner_in_expr(outer_in_name, 0, None, vector_slice_opt=is_vector)
 
     inner_in_names_to_input_taps: Dict[str, Tuple[int, ...]] = dict(
         zip(
@@ -232,7 +253,13 @@ def add_output_storage_post_proc_stmt(
         for in_tap in input_taps:
             tap_offset = in_tap + tap_storage_size
             assert tap_offset >= 0
-            add_inner_in_expr(outer_in_name, tap_offset, storage_size_name)
+            is_vector = outer_in_var.ndim == 1
+            add_inner_in_expr(
+                outer_in_name,
+                tap_offset,
+                storage_size_name,
+                vector_slice_opt=is_vector,
+            )
 
         output_taps = inner_in_names_to_output_taps.get(
             outer_in_name, [tap_storage_size]
@@ -253,7 +280,7 @@ def add_output_storage_post_proc_stmt(
         else:
             storage_size_stmt = ""
 
-        add_inner_in_expr(outer_in_name, None, None)
+        add_inner_in_expr(outer_in_name, None, None, vector_slice_opt=False)
 
         inner_out_to_outer_in_stmts.append(storage_name)
         output_idx = outer_output_names.index(storage_name)
@@ -325,7 +352,7 @@ def add_output_storage_post_proc_stmt(
         )
 
     for name in outer_in_non_seqs_names:
-        add_inner_in_expr(name, None, None)
+        add_inner_in_expr(name, None, None, vector_slice_opt=False)
 
     if op.info.as_while:
         # The inner function will return a boolean as the last value
@@ -333,9 +360,11 @@ def add_output_storage_post_proc_stmt(
 
     assert len(inner_in_exprs) == len(op.fgraph.inputs)
 
+    inner_scalar_in_args_to_temp_storage = "\n".join(inner_in_exprs_scalar)
     inner_in_args = create_arg_string(inner_in_exprs)
     inner_outputs = create_tuple_string(inner_output_names)
     input_storage_block = "\n".join(storage_alloc_stmts)
+    input_temp_scalar_storage_block = "\n".join(temp_scalar_storage_alloc_stmts)
     output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts)
     inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)
 
@@ -348,9 +377,13 @@ def scan({", ".join(outer_in_names)}):
 
 {indent(input_storage_block, " " * 4)}
 
+{indent(input_temp_scalar_storage_block, " " * 4)}
+
     i = 0
     cond = np.array(False)
     while i < n_steps and not cond.item():
+{indent(inner_scalar_in_args_to_temp_storage, " " * 8)}
+
         {inner_outputs} = scan_inner_func({inner_in_args})
 {indent(inner_out_post_processing_block, " " * 8)}
 {indent(inner_out_to_outer_out_stmts, " " * 8)}
diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py
index be0f64e02c..2481fc9a12 100644
--- a/tests/link/numba/test_scan.py
+++ b/tests/link/numba/test_scan.py
@@ -1,6 +1,7 @@
 import numpy as np
 import pytest
 
+import pytensor
 import pytensor.tensor as at
 from pytensor import config, function, grad
 from pytensor.compile.mode import Mode, get_mode
@@ -9,7 +10,7 @@
 from pytensor.scan.basic import scan
 from pytensor.scan.op import Scan
 from pytensor.scan.utils import until
-from pytensor.tensor import log, vector
+from pytensor.tensor import log, scalar, vector
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.utils import RandomStream
 from tests import unittest_tools as utt
@@ -442,3 +443,54 @@ def test_inner_graph_optimized():
     assert isinstance(inner_scan_node.op, Elemwise) and isinstance(
         inner_scan_node.op.scalar_op, Log1p
     )
+
+
+def test_vector_taps_benchmark(benchmark):
+    """Test vector taps performance.
+
+    Vector taps get indexed into numeric types, which must be wrapped back
+    into scalar arrays. The numba Scan implementation has an optimization
+    to reuse these scalar arrays instead of allocating them in every iteration.
+    """
+    n_steps = 1000
+
+    seq1 = vector("seq1", dtype="float64", shape=(n_steps,))
+    seq2 = vector("seq2", dtype="float64", shape=(n_steps,))
+    mitsot_init = vector("mitsot_init", dtype="float64", shape=(2,))
+    sitsot_init = scalar("sitsot_init", dtype="float64")
+
+    def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
+        mitsot3 = mitsot1 + seq2 + mitsot2 + seq1
+        sitsot2 = sitsot1 + mitsot3
+        return mitsot3, sitsot2
+
+    outs, _ = scan(
+        fn=step,
+        sequences=[seq1, seq2],
+        outputs_info=[
+            dict(initial=mitsot_init, taps=[-2, -1]),
+            dict(initial=sitsot_init, taps=[-1]),
+        ],
+    )
+
+    rng = np.random.default_rng(474)
+    test = {
+        seq1: rng.normal(size=n_steps),
+        seq2: rng.normal(size=n_steps),
+        mitsot_init: rng.normal(size=(2,)),
+        sitsot_init: rng.normal(),
+    }
+
+    numba_fn = pytensor.function(list(test.keys()), outs, mode=get_mode("NUMBA"))
+    scan_nodes = [
+        node for node in numba_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
+    ]
+    assert len(scan_nodes) == 1
+    numba_res = numba_fn(*test.values())
+
+    ref_fn = pytensor.function(list(test.keys()), outs, mode=get_mode("FAST_COMPILE"))
+    ref_res = ref_fn(*test.values())
+    for numba_r, ref_r in zip(numba_res, ref_res):
+        np.testing.assert_array_almost_equal(numba_r, ref_r)
+
+    benchmark(numba_fn, *test.values())
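
To make the change in PATCH 2/3 concrete, here is a minimal standalone sketch of the circular-buffer post-processing that the generated scan source performs. It is illustrative only, not the generated code itself: rotate_if_wrapped and the driver loop are made-up names. Each step writes its output at index (i + tap_size) % storage_size, so after the loop the buffer only needs to be rotated back into chronological order if indexing actually wrapped past the end of the storage, which is exactly the condition the patch now checks before computing the shift.

import numpy as np


def rotate_if_wrapped(storage, i, tap_size):
    """Rotate a scan output buffer into chronological order.

    `i` is the step counter after the loop finished; `tap_size` is the
    number of initial taps that pre-filled the buffer.
    """
    storage_size = len(storage)
    # The modulo and the concatenate are now skipped entirely unless the
    # write index wrapped around the end of the buffer at least once.
    if (i + tap_size) > storage_size:
        shift = (i + tap_size) % storage_size
        left = storage[:shift]
        right = storage[shift:]
        storage = np.concatenate((right, left))
    return storage


# Four steps written into a size-3 buffer with a single initial tap:
buf = np.empty(3)
buf[0] = -1.0  # initial tap, y[-1]
i = 0
while i < 4:
    buf[(i + 1) % 3] = float(i)  # step i produces output i
    i += 1
print(rotate_if_wrapped(buf, i, 1))  # [1. 2. 3.], oldest to newest

Note that when i + tap_size equals storage_size, the buffer was filled exactly once and is already in chronological order, which is why the strict > comparison suffices.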
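Likewise, a simplified sketch of the reuse optimization in PATCH 3/3, assuming a single float64 sequence and one sit-sot tap. inner_func and scan_like are hypothetical stand-ins for the compiled inner function and the generated scan wrapper: the 0-d arrays the inner function expects are allocated once before the loop (what temp_scalar_storage_alloc_stmts emits) and refilled in place on every iteration (the inner_in_exprs_scalar assignments), instead of a fresh scalar array being allocated on each step.

import numpy as np


def inner_func(seq_t, y_tm1):
    # Stand-in for the compiled inner graph; the jitted inner function
    # reads its scalar inputs as 0-d arrays.
    return seq_t[()] + y_tm1[()]


def scan_like(seq, y0, n_steps):
    y = np.empty(n_steps + 1)
    y[0] = y0
    # Allocated once, outside the loop, mirroring the statements that
    # temp_scalar_storage_alloc_stmts contributes to the generated source:
    seq_temp_scalar = np.empty((), dtype=seq.dtype)
    y_temp_scalar = np.empty((), dtype=y.dtype)
    i = 0
    while i < n_steps:
        # Refilled in place every iteration, mirroring the
        # inner_in_exprs_scalar assignments:
        seq_temp_scalar[()] = seq[i]
        y_temp_scalar[()] = y[i]
        y[i + 1] = inner_func(seq_temp_scalar, y_temp_scalar)
        i += 1
    return y[1:]


print(scan_like(np.arange(5.0), 0.0, 5))  # [ 0.  1.  3.  6. 10.]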