From e2e11442b9080c82c822635305c70ab64c60b181 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 12 Jun 2024 12:55:58 +0200
Subject: [PATCH] Support single multidimensional indexing in Numba via rewrites

---
 pytensor/tensor/rewriting/subtensor.py | 109 +++++++++++++++++++++++++
 tests/link/numba/test_subtensor.py     |  53 ++++++++++--
 2 files changed, 157 insertions(+), 5 deletions(-)

diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 32b81ff9a2..8ee86e6021 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -7,6 +7,7 @@
 import pytensor
 import pytensor.scalar.basic as ps
 from pytensor import compile
+from pytensor.compile import optdb
 from pytensor.graph.basic import Constant, Variable
 from pytensor.graph.rewriting.basic import (
     WalkingGraphRewriter,
@@ -1932,3 +1933,111 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):
     new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
     copy_stack_trace(node.outputs, new_out)
     return new_out
+
+
+@node_rewriter(tracks=[AdvancedSubtensor])
+def ravel_multidimensional_bool_idx(fgraph, node):
+    """Convert multidimensional boolean indexing into an equivalent vector boolean index, supported by Numba.
+
+    x[eye(3, dtype=bool)] -> x.ravel()[eye(3, dtype=bool).ravel()]
+    """
+    x, *idxs = node.inputs
+
+    if any(
+        isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int")
+        for idx in idxs
+    ):
+        # Get out if there are integer advanced indexes mixed in
+        return None
+
+    bool_idxs = [
+        (i, idx)
+        for i, idx in enumerate(idxs)
+        if (isinstance(idx.type, TensorType) and idx.dtype == "bool")
+    ]
+
+    if len(bool_idxs) != 1:
+        # Get out if there is not exactly one boolean index
+        return None
+
+    [(bool_idx_pos, bool_idx)] = bool_idxs
+    bool_idx_ndim = bool_idx.type.ndim
+    if bool_idx_ndim < 2:
+        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
+        return None
+
+    x_shape = x.shape
+    raveled_x = x.reshape(
+        (*x_shape[:bool_idx_pos], -1, *x_shape[bool_idx_pos + bool_idx_ndim :])
+    )
+
+    raveled_bool_idx = bool_idx.ravel()
+    new_idxs = list(idxs)
+    new_idxs[bool_idx_pos] = raveled_bool_idx
+
+    return [raveled_x[tuple(new_idxs)]]
+
+
+@node_rewriter(tracks=[AdvancedSubtensor])
+def ravel_multidimensional_int_idx(fgraph, node):
+    """Convert multidimensional integer indexing into an equivalent vector integer index, supported by Numba.
+
+    x[eye(3, dtype=int)] -> x[eye(3, dtype=int).ravel()].reshape((3, 3))
+
+
+    NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor`, except it also handles non-full slices
+
+    x[eye(3, dtype=int), 2:] -> x[eye(3, dtype=int).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output dimensions
+    """
+    x, *idxs = node.inputs
+
+    if any(
+        isinstance(idx.type, TensorType) and idx.type.dtype == "bool"
+        for idx in idxs
+    ):
+        # Get out if there are boolean advanced indexes mixed in
+        return None
+
+    int_idxs = [
+        (i, idx)
+        for i, idx in enumerate(idxs)
+        if (isinstance(idx.type, TensorType) and idx.dtype.startswith("int"))
+    ]
+
+    if len(int_idxs) != 1:
+        # Get out if there is not exactly one integer index
+        return None
+
+    [(int_idx_pos, int_idx)] = int_idxs
+    if int_idx.type.ndim < 2:
+        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
+        return None
+
+    raveled_int_idx = int_idx.ravel()
+    new_idxs = list(idxs)
+    new_idxs[int_idx_pos] = raveled_int_idx
+    raveled_subtensor = x[tuple(new_idxs)]
+
+    # Reshape into the correct output shape
+    # Because there is only one advanced index, the output dimension corresponding to the raveled integer index
+    # stays at the index's original position. With multiple advanced indexes it would have been moved to the front.
+    raveled_shape = raveled_subtensor.shape
+    unraveled_shape = (
+        *raveled_shape[:int_idx_pos],
+        *int_idx.shape,
+        *raveled_shape[int_idx_pos + 1 :],
+    )
+    return [raveled_subtensor.reshape(unraveled_shape)]
+
+
+optdb["specialize"].register(
+    ravel_multidimensional_bool_idx.__name__,
+    ravel_multidimensional_bool_idx,
+    "numba",
+)
+
+optdb["specialize"].register(
+    ravel_multidimensional_int_idx.__name__,
+    ravel_multidimensional_int_idx,
+    "numba",
+)
diff --git a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py
index 5e1784f368..ff335e30dc 100644
--- a/tests/link/numba/test_subtensor.py
+++ b/tests/link/numba/test_subtensor.py
@@ -19,7 +19,7 @@
     inc_subtensor,
     set_subtensor,
 )
-from tests.link.numba.test_basic import compare_numba_and_py
+from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
 
 
 rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
@@ -74,6 +74,7 @@ def test_AdvancedSubtensor1_out_of_bounds():
 @pytest.mark.parametrize(
     "x, indices, objmode_needed",
     [
+        # Single vector indexing (supported natively by Numba)
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             (0, [1, 2, 2, 3]),
@@ -84,25 +85,63 @@ def test_AdvancedSubtensor1_out_of_bounds():
             (np.array([True, False, False])),
             False,
         ),
+        (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
+        # Single multidimensional indexing (supported after specialization rewrites)
+        (
+            as_tensor(np.arange(3 * 3).reshape((3, 3))),
+            (np.eye(3).astype(int)),
+            False,
+        ),
         (
             as_tensor(np.arange(3 * 3).reshape((3, 3))),
             (np.eye(3).astype(bool)),
+            False,
+        ),
+        (
+            as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
+            (np.eye(3).astype(int)),
+            False,
+        ),
+        (
+            as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
+            (np.eye(3).astype(bool)),
+            False,
+        ),
+        (
+            as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
+            (slice(2, None), np.eye(3).astype(int)),
+            False,
+        ),
+        (
+            as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
+            (slice(2, None), np.eye(3).astype(bool)),
+            False,
+        ),
+        # Multiple advanced indexing, only supported in object mode
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            (slice(None), [1, 2], [3, 4]),
             True,
         ),
-        (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             ([1, 2], slice(None), [3, 4]),
             True,
         ),
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            ([[1, 2], [2, 1]], [0, 0]),
+            True,
+        ),
     ],
 )
 @pytest.mark.filterwarnings("error")
 def test_AdvancedSubtensor(x, indices, objmode_needed):
     """Test NumPy's advanced indexing in more than one dimension."""
-    out_pt = x[indices]
+    x_pt = x.type()
+    out_pt = x_pt[indices]
     assert isinstance(out_pt.owner.op, AdvancedSubtensor)
-    out_fg = FunctionGraph([], [out_pt])
+    out_fg = FunctionGraph([x_pt], [out_pt])
     with (
         pytest.warns(
             UserWarning,
@@ -111,7 +150,11 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
         if objmode_needed
         else contextlib.nullcontext()
    ):
-        compare_numba_and_py(out_fg, [])
+        compare_numba_and_py(
+            out_fg,
+            [x.data],
+            numba_mode=numba_mode.including("specialize"),
+        )
 
 
 @pytest.mark.parametrize(
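
For readers unfamiliar with the trick used by ravel_multidimensional_bool_idx, the following standalone NumPy sketch (not part of the patch; shapes and names are illustrative only) shows the equivalence the rewrite relies on: a multidimensional boolean mask can be replaced by collapsing the masked axes of x into a single axis and indexing it with the raveled 1-D mask, which is the form Numba supports.

import numpy as np

x = np.arange(2 * 3 * 3).reshape((2, 3, 3))
mask = np.eye(3, dtype=bool)  # 2-D boolean index applied at axis position 1

# Direct multidimensional boolean indexing: the mask consumes mask.ndim axes of x.
direct = x[:, mask]

# Equivalent form produced by the rewrite: collapse the masked axes into one
# axis and index it with the raveled (1-D) mask.
raveled_x = x.reshape((x.shape[0], -1))
rewritten = raveled_x[:, mask.ravel()]

assert np.array_equal(direct, rewritten)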
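Likewise, a minimal NumPy sketch (again illustrative, not part of the patch) of the reshape bookkeeping done by ravel_multidimensional_int_idx: the multidimensional integer index is raveled before indexing, and its shape is restored afterwards at the same axis position in the output.

import numpy as np

x = np.arange(3 * 4 * 5).reshape((3, 4, 5))
idx = np.eye(3, dtype=int)  # (3, 3) integer index into axis 0

# Direct multidimensional integer indexing.
direct = x[idx]  # shape (3, 3, 4, 5)

# Equivalent form produced by the rewrite: index with the raveled (1-D) version,
# then unravel the leading output dimension back to the index's shape.
raveled = x[idx.ravel()]  # shape (9, 4, 5)
rewritten = raveled.reshape((*idx.shape, *raveled.shape[1:]))

assert np.array_equal(direct, rewritten)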
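Finally, a sketch of how the new path might be exercised end to end, in the spirit of the updated test. Assumptions: Numba is installed, and the "numba"-tagged specialize rewrites registered above are applied by the mode in use (the test suite opts in explicitly via numba_mode.including("specialize")).

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
out = x[np.eye(3, dtype=bool)]  # AdvancedSubtensor with a 2-D boolean index

# Compile with the Numba backend; after the rewrite the graph should only
# contain a 1-D boolean index, which Numba handles natively.
fn = pytensor.function([x], out, mode="NUMBA")

x_val = np.arange(9, dtype=pytensor.config.floatX).reshape((3, 3))
np.testing.assert_allclose(fn(x_val), x_val[np.eye(3, dtype=bool)])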