Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
clone,
ancestors,
)
from pytensor.graph.replace import clone_replace, graph_replace, vectorize
from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.graph.fg import FunctionGraph
Expand Down
12 changes: 9 additions & 3 deletions pytensor/graph/replace.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections.abc import Iterable, Mapping, Sequence
from functools import partial, singledispatch
from typing import Optional, Union, cast, overload
Expand Down Expand Up @@ -215,22 +216,22 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:


@overload
def vectorize(
def vectorize_graph(
outputs: Variable,
replace: Mapping[Variable, Variable],
) -> Variable:
...


@overload
def vectorize(
def vectorize_graph(
outputs: Sequence[Variable],
replace: Mapping[Variable, Variable],
) -> Sequence[Variable]:
...


def vectorize(
def vectorize_graph(
outputs: Union[Variable, Sequence[Variable]],
replace: Mapping[Variable, Variable],
) -> Union[Variable, Sequence[Variable]]:
Expand Down Expand Up @@ -309,3 +310,8 @@ def transform(var: Variable) -> Variable:
else:
[vect_output] = seq_vect_outputs
return vect_output


def vectorize(*args, **kwargs):
warnings.warn("vectorize was renamed to vectorize_graph", UserWarning)
return vectorize_node(*args, **kwargs)
1 change: 1 addition & 0 deletions pytensor/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
from pytensor.tensor.type import * # noqa
from pytensor.tensor.type_other import * # noqa
from pytensor.tensor.variable import TensorConstant, TensorVariable # noqa
from pytensor.tensor.functional import vectorize # noqa

# Allow accessing numpy constants from pytensor.tensor
from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi # noqa
Expand Down
50 changes: 7 additions & 43 deletions pytensor/tensor/blockwise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import re
from collections.abc import Sequence
from typing import Any, Optional, cast

Expand All @@ -9,53 +8,18 @@
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node, vectorize
from pytensor.graph.replace import _vectorize_node, vectorize_graph
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string
from pytensor.tensor.utils import (
_parse_gufunc_signature,
broadcast_static_dim_lengths,
import_func_from_string,
)
from pytensor.tensor.variable import TensorVariable


# TODO: Implement vectorize helper to batch whole graphs (similar to what Blockwise does for the grad)

# Copied verbatim from numpy.lib.function_base
# https://github.com/numpy/numpy/blob/f2db090eb95b87d48a3318c9a3f9d38b67b0543c/numpy/lib/function_base.py#L1999-L2029
_DIMENSION_NAME = r"\w+"
_CORE_DIMENSION_LIST = "(?:{0:}(?:,{0:})*)?".format(_DIMENSION_NAME)
_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)"
_ARGUMENT_LIST = "{0:}(?:,{0:})*".format(_ARGUMENT)
_SIGNATURE = "^{0:}->{0:}$".format(_ARGUMENT_LIST)


def _parse_gufunc_signature(signature):
"""
Parse string signatures for a generalized universal function.

Arguments
---------
signature : string
Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)``
for ``np.matmul``.

Returns
-------
Tuple of input and output core dimensions parsed from the signature, each
of the form List[Tuple[str, ...]].
"""
signature = re.sub(r"\s+", "", signature)

if not re.match(_SIGNATURE, signature):
raise ValueError(f"not a valid gufunc signature: {signature}")
return tuple(
[
tuple(re.findall(_DIMENSION_NAME, arg))
for arg in re.findall(_ARGUMENT, arg_list)
]
for arg_list in signature.split("->")
)


def safe_signature(
core_inputs: Sequence[Variable],
core_outputs: Sequence[Variable],
Expand Down Expand Up @@ -274,7 +238,7 @@ def as_core(t, core_t):

core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)

igrads = vectorize(
igrads = vectorize_graph(
[core_igrad for core_igrad in core_igrads if core_igrad is not None],
replace=dict(
zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
Expand Down
125 changes: 125 additions & 0 deletions pytensor/tensor/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from typing import Callable, Optional

from pytensor.graph import vectorize_graph
from pytensor.tensor import TensorVariable
from pytensor.tensor.utils import _parse_gufunc_signature


def vectorize(func: Callable, signature: Optional[str] = None) -> Callable:
"""Create a vectorized version of a python function that takes TensorVariables as inputs and outputs.

Similar to numpy.vectorize. See respective docstrings for more details.

Parameters
----------
func: Callable
Function that creates the desired outputs from TensorVariable inputs with the core dimensions.
signature: str, optional
Generalized universal function signature, e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication.
If not provided, it is assumed all inputs have scalar core dimensions. Unlike numpy, the outputs
can have arbitrary shapes when the signature is not provided.

Returns
-------
vectorized_func: Callable
Callable that takes TensorVariables with arbitrarily batched dimensions on the left
and returns variables whose graphs correspond to the vectorized expressions of func.

Notes
-----
Unlike numpy.vectorize, the equality of core dimensions implied by the signature is not explicitly asserted.

To vectorize an existing graph, use `pytensor.graph.replace.vectorize_graph` instead.


Examples
--------
.. code-block:: python

import pytensor
import pytensor.tensor as pt

def func(x):
return pt.exp(x) / pt.sum(pt.exp(x))

vec_func = pt.vectorize(func, signature="(a)->(a)")

x = pt.matrix("x")
y = vec_func(x)

fn = pytensor.function([x], y)
fn([[0, 1, 2], [2, 1, 0]])
# array([[0.09003057, 0.24472847, 0.66524096],
# [0.66524096, 0.24472847, 0.09003057]])


.. code-block:: python

import pytensor
import pytensor.tensor as pt

def func(x):
return x[0], x[-1]

vec_func = pt.vectorize(func, signature="(a)->(),()")

x = pt.matrix("x")
y1, y2 = vec_func(x)

fn = pytensor.function([x], [y1, y2])
fn([[-10, 0, 10], [-11, 0, 11]])
# [array([-10., -11.]), array([10., 11.])]

"""

def inner(*inputs):
if signature is None:
# Assume all inputs are scalar
inputs_sig = [()] * len(inputs)
else:
inputs_sig, outputs_sig = _parse_gufunc_signature(signature)
if len(inputs) != len(inputs_sig):
raise ValueError(
f"Number of inputs does not match signature: {signature}"
)

# Create dummy core inputs by stripping the batched dimensions of inputs
core_inputs = []
for input, input_sig in zip(inputs, inputs_sig):
if not isinstance(input, TensorVariable):
raise TypeError(
f"Inputs to vectorize function must be TensorVariable, got {type(input)}"
)

if input.ndim < len(input_sig):
raise ValueError(
f"Input {input} has less dimensions than signature {input_sig}"
)
if len(input_sig):
core_shape = input.type.shape[-len(input_sig) :]
else:
core_shape = ()

core_input = input.type.clone(shape=core_shape)(name=input.name)
core_inputs.append(core_input)

# Call function on dummy core inputs
core_outputs = func(*core_inputs)
if core_outputs is None:
raise ValueError("vectorize function returned no outputs")

if signature is not None:
if isinstance(core_outputs, (list, tuple)):
n_core_outputs = len(core_outputs)
else:
n_core_outputs = 1
if n_core_outputs != len(outputs_sig):
raise ValueError(
f"Number of outputs does not match signature: {signature}"
)

# Vectorize graph by replacing dummy core inputs by original inputs
outputs = vectorize_graph(core_outputs, replace=dict(zip(core_inputs, inputs)))
return outputs

return inner
38 changes: 38 additions & 0 deletions pytensor/tensor/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from collections.abc import Sequence
from typing import Union

Expand Down Expand Up @@ -161,3 +162,40 @@ def broadcast_static_dim_lengths(
if len(dim_lengths_set) > 1:
raise ValueError
return tuple(dim_lengths_set)[0]


# Copied verbatim from numpy.lib.function_base
# https://github.com/numpy/numpy/blob/f2db090eb95b87d48a3318c9a3f9d38b67b0543c/numpy/lib/function_base.py#L1999-L2029
_DIMENSION_NAME = r"\w+"
_CORE_DIMENSION_LIST = "(?:{0:}(?:,{0:})*)?".format(_DIMENSION_NAME)
_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)"
_ARGUMENT_LIST = "{0:}(?:,{0:})*".format(_ARGUMENT)
_SIGNATURE = "^{0:}->{0:}$".format(_ARGUMENT_LIST)


def _parse_gufunc_signature(signature):
"""
Parse string signatures for a generalized universal function.

Arguments
---------
signature : string
Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)``
for ``np.matmul``.

Returns
-------
Tuple of input and output core dimensions parsed from the signature, each
of the form List[Tuple[str, ...]].
"""
signature = re.sub(r"\s+", "", signature)

if not re.match(_SIGNATURE, signature):
raise ValueError(f"not a valid gufunc signature: {signature}")
return tuple(
[
tuple(re.findall(_DIMENSION_NAME, arg))
for arg in re.findall(_ARGUMENT, arg_list)
]
for arg_list in signature.split("->")
)
10 changes: 5 additions & 5 deletions tests/graph/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytensor.tensor as pt
from pytensor import config, function, shared
from pytensor.graph.basic import equal_computations, graph_inputs
from pytensor.graph.replace import clone_replace, graph_replace, vectorize
from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable
Expand Down Expand Up @@ -226,18 +226,18 @@ def test_graph_replace_disconnected(self):
oc = graph_replace([o], {fake: x.clone()}, strict=True)


class TestVectorize:
class TestVectorizeGraph:
# TODO: Add tests with multiple outputs, constants, and other singleton types

def test_basic(self):
x = pt.vector("x")
y = pt.exp(x) / pt.sum(pt.exp(x))

new_x = pt.matrix("new_x")
[new_y] = vectorize([y], {x: new_x})
[new_y] = vectorize_graph([y], {x: new_x})

# Check we can pass both a sequence or a single variable
alt_new_y = vectorize(y, {x: new_x})
alt_new_y = vectorize_graph(y, {x: new_x})
assert equal_computations([new_y], [alt_new_y])

fn = function([new_x], new_y)
Expand All @@ -253,7 +253,7 @@ def test_multiple_outputs(self):
y2 = x[-1]

new_x = pt.matrix("new_x")
[new_y1, new_y2] = vectorize([y1, y2], {x: new_x})
[new_y1, new_y2] = vectorize_graph([y1, y2], {x: new_x})

fn = function([new_x], [new_y1, new_y2])
new_x_test = np.arange(9).reshape(3, 3).astype(config.floatX)
Expand Down
3 changes: 2 additions & 1 deletion tests/tensor/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
from pytensor.tensor.utils import _parse_gufunc_signature


def test_vectorize_blockwise():
Expand Down
Loading