From 3f5b34666240798c55222830e045df648aa19143 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 15 Nov 2023 15:22:50 +0100 Subject: [PATCH 1/2] Rename replace/vectorize to replace/vectorize_graph --- pytensor/graph/__init__.py | 2 +- pytensor/graph/replace.py | 12 +++++++++--- pytensor/tensor/blockwise.py | 4 ++-- tests/graph/test_replace.py | 10 +++++----- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/pytensor/graph/__init__.py b/pytensor/graph/__init__.py index f7c4202452..189dfed237 100644 --- a/pytensor/graph/__init__.py +++ b/pytensor/graph/__init__.py @@ -9,7 +9,7 @@ clone, ancestors, ) -from pytensor.graph.replace import clone_replace, graph_replace, vectorize +from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph from pytensor.graph.op import Op from pytensor.graph.type import Type from pytensor.graph.fg import FunctionGraph diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py index abaf839dbb..3c07e21232 100644 --- a/pytensor/graph/replace.py +++ b/pytensor/graph/replace.py @@ -1,3 +1,4 @@ +import warnings from collections.abc import Iterable, Mapping, Sequence from functools import partial, singledispatch from typing import Optional, Union, cast, overload @@ -215,7 +216,7 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply: @overload -def vectorize( +def vectorize_graph( outputs: Variable, replace: Mapping[Variable, Variable], ) -> Variable: @@ -223,14 +224,14 @@ def vectorize( @overload -def vectorize( +def vectorize_graph( outputs: Sequence[Variable], replace: Mapping[Variable, Variable], ) -> Sequence[Variable]: ... 
-def vectorize( +def vectorize_graph( outputs: Union[Variable, Sequence[Variable]], replace: Mapping[Variable, Variable], ) -> Union[Variable, Sequence[Variable]]: @@ -309,3 +310,8 @@ def transform(var: Variable) -> Variable: else: [vect_output] = seq_vect_outputs return vect_output + + +def vectorize(*args, **kwargs): + warnings.warn("vectorize was renamed to vectorize_graph", UserWarning) + return vectorize_graph(*args, **kwargs) diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index d429c461d6..1ad69647fa 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -9,7 +9,7 @@ from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.null_type import NullType from pytensor.graph.op import Op -from pytensor.graph.replace import _vectorize_node, vectorize +from pytensor.graph.replace import _vectorize_node, vectorize_graph from pytensor.tensor import as_tensor_variable from pytensor.tensor.shape import shape_padleft from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor @@ -274,7 +274,7 @@ def as_core(t, core_t): core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds) - igrads = vectorize( + igrads = vectorize_graph( [core_igrad for core_igrad in core_igrads if core_igrad is not None], replace=dict( zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds) diff --git a/tests/graph/test_replace.py b/tests/graph/test_replace.py index aefc7d69da..6018517dbf 100644 --- a/tests/graph/test_replace.py +++ b/tests/graph/test_replace.py @@ -5,7 +5,7 @@ import pytensor.tensor as pt from pytensor import config, function, shared from pytensor.graph.basic import equal_computations, graph_inputs -from pytensor.graph.replace import clone_replace, graph_replace, vectorize +from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph from pytensor.tensor import dvector, fvector, vector from tests import unittest_tools as utt from tests.graph.utils 
import MyOp, MyVariable @@ -226,7 +226,7 @@ def test_graph_replace_disconnected(self): oc = graph_replace([o], {fake: x.clone()}, strict=True) -class TestVectorize: +class TestVectorizeGraph: # TODO: Add tests with multiple outputs, constants, and other singleton types def test_basic(self): @@ -234,10 +234,10 @@ def test_basic(self): y = pt.exp(x) / pt.sum(pt.exp(x)) new_x = pt.matrix("new_x") - [new_y] = vectorize([y], {x: new_x}) + [new_y] = vectorize_graph([y], {x: new_x}) # Check we can pass both a sequence or a single variable - alt_new_y = vectorize(y, {x: new_x}) + alt_new_y = vectorize_graph(y, {x: new_x}) assert equal_computations([new_y], [alt_new_y]) fn = function([new_x], new_y) @@ -253,7 +253,7 @@ def test_multiple_outputs(self): y2 = x[-1] new_x = pt.matrix("new_x") - [new_y1, new_y2] = vectorize([y1, y2], {x: new_x}) + [new_y1, new_y2] = vectorize_graph([y1, y2], {x: new_x}) fn = function([new_x], [new_y1, new_y2]) new_x_test = np.arange(9).reshape(3, 3).astype(config.floatX) From 225f429396e40891ccf9efa226f33b5ef049f117 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 15 Nov 2023 16:35:40 +0100 Subject: [PATCH 2/2] Add functional vectorize helper to pytensor.tensor module --- pytensor/tensor/__init__.py | 1 + pytensor/tensor/blockwise.py | 46 ++---------- pytensor/tensor/functional.py | 125 ++++++++++++++++++++++++++++++++ pytensor/tensor/utils.py | 38 ++++++++++ tests/tensor/test_blockwise.py | 3 +- tests/tensor/test_functional.py | 81 +++++++++++++++++++++ 6 files changed, 252 insertions(+), 42 deletions(-) create mode 100644 pytensor/tensor/functional.py create mode 100644 tests/tensor/test_functional.py diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index 3f6fae35dd..3112892697 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -148,6 +148,7 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int: from pytensor.tensor.type import * # noqa from 
pytensor.tensor.type_other import * # noqa from pytensor.tensor.variable import TensorConstant, TensorVariable # noqa +from pytensor.tensor.functional import vectorize # noqa # Allow accessing numpy constants from pytensor.tensor from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi # noqa diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 1ad69647fa..8e44fbbbc6 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -1,4 +1,3 @@ -import re from collections.abc import Sequence from typing import Any, Optional, cast @@ -13,49 +12,14 @@ from pytensor.tensor import as_tensor_variable from pytensor.tensor.shape import shape_padleft from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor -from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string +from pytensor.tensor.utils import ( + _parse_gufunc_signature, + broadcast_static_dim_lengths, + import_func_from_string, +) from pytensor.tensor.variable import TensorVariable -# TODO: Implement vectorize helper to batch whole graphs (similar to what Blockwise does for the grad) - -# Copied verbatim from numpy.lib.function_base -# https://github.com/numpy/numpy/blob/f2db090eb95b87d48a3318c9a3f9d38b67b0543c/numpy/lib/function_base.py#L1999-L2029 -_DIMENSION_NAME = r"\w+" -_CORE_DIMENSION_LIST = "(?:{0:}(?:,{0:})*)?".format(_DIMENSION_NAME) -_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)" -_ARGUMENT_LIST = "{0:}(?:,{0:})*".format(_ARGUMENT) -_SIGNATURE = "^{0:}->{0:}$".format(_ARGUMENT_LIST) - - -def _parse_gufunc_signature(signature): - """ - Parse string signatures for a generalized universal function. - - Arguments - --------- - signature : string - Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)`` - for ``np.matmul``. - - Returns - ------- - Tuple of input and output core dimensions parsed from the signature, each - of the form List[Tuple[str, ...]]. 
- """ - signature = re.sub(r"\s+", "", signature) - - if not re.match(_SIGNATURE, signature): - raise ValueError(f"not a valid gufunc signature: {signature}") - return tuple( - [ - tuple(re.findall(_DIMENSION_NAME, arg)) - for arg in re.findall(_ARGUMENT, arg_list) - ] - for arg_list in signature.split("->") - ) - - def safe_signature( core_inputs: Sequence[Variable], core_outputs: Sequence[Variable], diff --git a/pytensor/tensor/functional.py b/pytensor/tensor/functional.py new file mode 100644 index 0000000000..bc7bef6e0a --- /dev/null +++ b/pytensor/tensor/functional.py @@ -0,0 +1,125 @@ +from typing import Callable, Optional + +from pytensor.graph import vectorize_graph +from pytensor.tensor import TensorVariable +from pytensor.tensor.utils import _parse_gufunc_signature + + +def vectorize(func: Callable, signature: Optional[str] = None) -> Callable: + """Create a vectorized version of a python function that takes TensorVariables as inputs and outputs. + + Similar to numpy.vectorize. See respective docstrings for more details. + + Parameters + ---------- + func: Callable + Function that creates the desired outputs from TensorVariable inputs with the core dimensions. + signature: str, optional + Generalized universal function signature, e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication. + If not provided, it is assumed all inputs have scalar core dimensions. Unlike numpy, the outputs + can have arbitrary shapes when the signature is not provided. + + Returns + ------- + vectorized_func: Callable + Callable that takes TensorVariables with arbitrarily batched dimensions on the left + and returns variables whose graphs correspond to the vectorized expressions of func. + + Notes + ----- + Unlike numpy.vectorize, the equality of core dimensions implied by the signature is not explicitly asserted. + + To vectorize an existing graph, use `pytensor.graph.replace.vectorize_graph` instead. + + + Examples + -------- + .. 
code-block:: python + + import pytensor + import pytensor.tensor as pt + + def func(x): + return pt.exp(x) / pt.sum(pt.exp(x)) + + vec_func = pt.vectorize(func, signature="(a)->(a)") + + x = pt.matrix("x") + y = vec_func(x) + + fn = pytensor.function([x], y) + fn([[0, 1, 2], [2, 1, 0]]) + # array([[0.09003057, 0.24472847, 0.66524096], + # [0.66524096, 0.24472847, 0.09003057]]) + + + .. code-block:: python + + import pytensor + import pytensor.tensor as pt + + def func(x): + return x[0], x[-1] + + vec_func = pt.vectorize(func, signature="(a)->(),()") + + x = pt.matrix("x") + y1, y2 = vec_func(x) + + fn = pytensor.function([x], [y1, y2]) + fn([[-10, 0, 10], [-11, 0, 11]]) + # [array([-10., -11.]), array([10., 11.])] + + """ + + def inner(*inputs): + if signature is None: + # Assume all inputs are scalar + inputs_sig = [()] * len(inputs) + else: + inputs_sig, outputs_sig = _parse_gufunc_signature(signature) + if len(inputs) != len(inputs_sig): + raise ValueError( + f"Number of inputs does not match signature: {signature}" + ) + + # Create dummy core inputs by stripping the batched dimensions of inputs + core_inputs = [] + for input, input_sig in zip(inputs, inputs_sig): + if not isinstance(input, TensorVariable): + raise TypeError( + f"Inputs to vectorize function must be TensorVariable, got {type(input)}" + ) + + if input.ndim < len(input_sig): + raise ValueError( + f"Input {input} has less dimensions than signature {input_sig}" + ) + if len(input_sig): + core_shape = input.type.shape[-len(input_sig) :] + else: + core_shape = () + + core_input = input.type.clone(shape=core_shape)(name=input.name) + core_inputs.append(core_input) + + # Call function on dummy core inputs + core_outputs = func(*core_inputs) + if core_outputs is None: + raise ValueError("vectorize function returned no outputs") + + if signature is not None: + if isinstance(core_outputs, (list, tuple)): + n_core_outputs = len(core_outputs) + else: + n_core_outputs = 1 + if n_core_outputs != 
len(outputs_sig): + raise ValueError( + f"Number of outputs does not match signature: {signature}" + ) + + # Vectorize graph by replacing dummy core inputs by original inputs + outputs = vectorize_graph(core_outputs, replace=dict(zip(core_inputs, inputs))) + return outputs + + return inner diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 592a9997f6..ed00663387 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -1,3 +1,4 @@ +import re from collections.abc import Sequence from typing import Union @@ -161,3 +162,40 @@ def broadcast_static_dim_lengths( if len(dim_lengths_set) > 1: raise ValueError return tuple(dim_lengths_set)[0] + + +# Copied verbatim from numpy.lib.function_base +# https://github.com/numpy/numpy/blob/f2db090eb95b87d48a3318c9a3f9d38b67b0543c/numpy/lib/function_base.py#L1999-L2029 +_DIMENSION_NAME = r"\w+" +_CORE_DIMENSION_LIST = "(?:{0:}(?:,{0:})*)?".format(_DIMENSION_NAME) +_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)" +_ARGUMENT_LIST = "{0:}(?:,{0:})*".format(_ARGUMENT) +_SIGNATURE = "^{0:}->{0:}$".format(_ARGUMENT_LIST) + + +def _parse_gufunc_signature(signature): + """ + Parse string signatures for a generalized universal function. + + Arguments + --------- + signature : string + Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)`` + for ``np.matmul``. + + Returns + ------- + Tuple of input and output core dimensions parsed from the signature, each + of the form List[Tuple[str, ...]]. 
+ """ + signature = re.sub(r"\s+", "", signature) + + if not re.match(_SIGNATURE, signature): + raise ValueError(f"not a valid gufunc signature: {signature}") + return tuple( + [ + tuple(re.findall(_DIMENSION_NAME, arg)) + for arg in re.findall(_ARGUMENT, arg_list) + ] + for arg_list in signature.split("->") + ) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index d761641b49..345dfa0873 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -10,9 +10,10 @@ from pytensor.graph import Apply, Op from pytensor.graph.replace import vectorize_node from pytensor.tensor import diagonal, log, tensor -from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular +from pytensor.tensor.utils import _parse_gufunc_signature def test_vectorize_blockwise(): diff --git a/tests/tensor/test_functional.py b/tests/tensor/test_functional.py new file mode 100644 index 0000000000..df0472f678 --- /dev/null +++ b/tests/tensor/test_functional.py @@ -0,0 +1,81 @@ +import numpy as np +import pytest + +from pytensor.graph.basic import equal_computations +from pytensor.tensor import full, tensor +from pytensor.tensor.functional import vectorize +from pytensor.tensor.random.type import RandomGeneratorType + + +class TestVectorize: + def test_vectorize_no_signature(self): + """Unlike numpy we don't assume outputs of vectorize without signature are scalar.""" + + def func(x): + return full((5, 3), x) + + vec_func = vectorize(func) + + x = tensor("x", shape=(4,), dtype="float64") + out = vec_func(x) + + assert out.type.ndim == 3 + test_x = np.array([1, 2, 3, 4]) + np.testing.assert_allclose( + out.eval({x: test_x}), np.full((len(test_x), 5, 3), test_x[:, None, None]) + ) + + def test_vectorize_outer_product(self): + def func(x, y): + return x[:, None] * 
y[None, :] + + vec_func = vectorize(func, signature="(a),(b)->(a,b)") + + x = tensor("x", shape=(2, 3, 5)) + y = tensor("y", shape=(2, 3, 7)) + out = vec_func(x, y) + + assert out.type.shape == (2, 3, 5, 7) + assert equal_computations([out], [x[..., :, None] * y[..., None, :]]) + + def test_vectorize_outer_inner_product(self): + def func(x, y): + return x[:, None] * y[None, :], (x * y).sum() + + vec_func = vectorize(func, signature="(a),(b)->(a,b),()") + + x = tensor("x", shape=(2, 3, 5)) + y = tensor("y", shape=(2, 3, 5)) + outer, inner = vec_func(x, y) + + assert outer.type.shape == (2, 3, 5, 5) + assert inner.type.shape == (2, 3) + assert equal_computations([outer], [x[..., :, None] * y[..., None, :]]) + assert equal_computations([inner], [(x * y).sum(axis=-1)]) + + def test_errors(self): + def func(x, y): + return x + y, x - y + + x = tensor("x", shape=(5,)) + y = tensor("y", shape=()) + + with pytest.raises(ValueError, match="Number of inputs"): + vectorize(func, signature="(),()->()")(x) + + with pytest.raises(ValueError, match="Number of outputs"): + vectorize(func, signature="(),()->()")(x, y) + + with pytest.raises(ValueError, match="Input y has less dimensions"): + vectorize(func, signature="(a),(a)->(a),(a)")(x, y) + + bad_input = RandomGeneratorType() + + with pytest.raises(TypeError, match="must be TensorVariable"): + vectorize(func)(bad_input, x) + + def bad_func(x, y): + x + y + + with pytest.raises(ValueError, match="no outputs"): + vectorize(bad_func)(x, y)