diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py
index 7d4f7e39f3..0a8b15223d 100644
--- a/pytensor/compile/builders.py
+++ b/pytensor/compile/builders.py
@@ -1,12 +1,11 @@
 """Define new Ops from existing Ops"""
 
-from collections import OrderedDict
-from collections.abc import Sequence
+import warnings
+from collections.abc import Callable, Sequence
 from copy import copy
 from functools import partial
-from typing import cast
+from typing import Union, cast
 
-import pytensor.tensor as pt
 from pytensor.compile.function import function
 from pytensor.compile.function.pfunc import rebuild_collect_shared
 from pytensor.compile.mode import optdb
@@ -160,16 +159,9 @@ class OpFromGraph(Op, HasInnerGraph):
     Currently does not support ``updates`` or ``givens`` argument.
 
     .. TODO:
-        - examples for a multi-layer mlp. where?
-        - __hash__, __eq__ otherwise won't merge, try
-          is_same_graph_with_merge(op1.local_outputs, op2,
-          local_outputs)
-        - c_code() to remove the double overhead?
-        - grad() make it support DisconnectedType and the new interface
-        - add support for NullType and DisconnectedType when R_op supports them
-        - check how it works with updates.
+        - Allow / test merging of OpFromGraph nodes
+        - Add support for NullType and DisconnectedType when R_op supports them
         - Add support to pickle this Op.
-        - Add support/test with random generator
         - Add optimization to removing unused inputs/outputs
         - Add optimization to work inplace on inputs when not inline
@@ -186,7 +178,7 @@ class OpFromGraph(Op, HasInnerGraph):
     - For overriding, it's recommended to provide pure functions (no side
       effects like setting global variable) as callable(s). The callable(s)
       supplied for overriding gradient/rop will be called only once at the
-      first call to grad/R_op, and will be converted to OpFromGraph instances.
+      first call to L_op/R_op, and will be converted to OpFromGraph instances.
 
     Examples
     --------
@@ -221,7 +213,7 @@ class OpFromGraph(Op, HasInnerGraph):
         e2 = op(x, y, z) + op(z, y, x)
         fn = function([x, y, z], [e2])
 
-    Example 3 override gradient
+    Example 3 override second output of L_op
 
     .. code-block:: python
 
@@ -230,12 +222,15 @@
         x, y, z = pt.scalars('xyz')
         e = x + y * z
 
-        def rescale_dy(inps, grads):
+        def rescale_dy(inps, outputs, out_grads):
             x, y, z = inps
-            g, = grads
+            g, = out_grads
             return z*2
         op = OpFromGraph(
-            [x, y, z], [e], grad_overrides=['default', rescale_dy, 'default']
+            [x, y, z],
+            [e],
+            lop_overrides=[None, rescale_dy, None],
+        )
         e2 = op(x, y, z)
         dx, dy, dz = grad(e2, [x, y, z])
         fn = function([x, y, z], [dx, dy, dz])
@@ -244,66 +239,15 @@
 
     """
 
-    TYPE_ERR_MSG = (
-        "L_op/gradient override should be (single or list of)"
-        "'default' | OpFromGraph | callable | Variable "
-        "with NullType or DisconnectedType, got %s"
-    )
-    STYPE_ERR_MSG = (
-        "Overriding Variable instance can only have type"
-        " of DisconnectedType or NullType, got %s"
-    )
-    LOP_TYPE_ERR_MSG = 'L_op type can only be "grad" or "lop", got %s.'
-    OV_INP_LEN_ERR_MSG = "expect overrider with %d inputs, got %d"
-
-    @staticmethod
-    def _filter_grad_var(grad, inp):
-        # Returns (filtered_var, overrider_var)
-        # Args:
-        #     grad: gradient Variable
-        #     inp: the corresponding input of gradient Variable
-        #
-        # a grad() call could return instance of NullType() or DisconnectedType()
-        # which cannot be directly used in OfG
-        #
-        # Since we always use an OfG instance as self._lop_op, the current
-        # workaround is to "remember" the special cases of the gradient and
-        # replace them after self._lop_op is called.
-        #
-        # This helper function changes invalid types into a filtered_var,
-        # and provides a overrider_var to be replaced at grad() call
-        #
-        # For now, this converts NullType or DisconnectedType into zeros_like.
-        # other types are unmodified: overrider_var -> None
-        if isinstance(grad.type, NullType | DisconnectedType):
-            if hasattr(inp, "zeros_like"):
-                return inp.zeros_like(), grad
-            else:
-                return pt.constant(0.0), grad
-        else:
-            return grad, None
-
-    @staticmethod
-    def _filter_rop_var(inpJ, out):
-        # mostly similar to _filter_grad_var
-        if isinstance(inpJ.type, NullType):
-            return out.zeros_like(), inpJ
-        if isinstance(inpJ.type, DisconnectedType):
-            # since R_op does not have DisconnectedType yet, we will just
-            # make them zeros.
-            return out.zeros_like(), None
-        else:
-            return inpJ, None
-
     def __init__(
         self,
         inputs: list[Variable],
         outputs: list[Variable],
         *,
         inline: bool = False,
-        lop_overrides: str = "default",
-        grad_overrides: str = "default",
-        rop_overrides: str = "default",
+        lop_overrides: Union[Callable, "OpFromGraph", None] = None,
+        grad_overrides: Union[Callable, "OpFromGraph", None] = None,
+        rop_overrides: Union[Callable, "OpFromGraph", None] = None,
         connection_pattern: list[list[bool]] | None = None,
         strict: bool = False,
         name: str | None = None,
@@ -314,8 +258,10 @@ def __init__(
         ----------
         inputs
             The inputs to the graph.
+
         outputs
             The outputs to the graph.
+
         inline
             Defaults to ``False``
 
@@ -324,11 +270,12 @@ def __init__(
             graph but rather its internal graph.
 
            ``False`` : will use a pre-compiled function inside.
+
         grad_overrides
-            Defaults to ``'default'``.
+            Defaults to ``None``.
             This argument is mutually exclusive with ``lop_overrides``.
 
-            ``'default'`` : Do not override, use default grad() result
+            ``None`` : Do not override, use default grad() result
 
             `OpFromGraph`: Override with another `OpFromGraph`, should
             accept inputs as the same order and types of ``inputs`` and ``output_grads``
@@ -337,15 +284,16 @@ def __init__(
            `callable`: Should take two args: ``inputs`` and ``output_grads``.
             Each argument is expected to be a list of :class:`Variable `.
             Must return list of :class:`Variable `.
+
         lop_overrides
-            Defaults to ``'default'``.
+            Defaults to ``None``.
 
             This argument is mutually exclusive with ``grad_overrides``.
 
             These options are similar to the ``grad_overrides`` above, but for
             the :meth:`Op.L_op` method.
 
-            ``'default'``: Do not override, use the default :meth:`Op.L_op` result
+            ``None``: Do not override, use the default :meth:`Op.L_op` result
 
             `OpFromGraph`: Override with another `OpFromGraph`, should
             accept inputs as the same order and types of ``inputs``,
@@ -356,20 +304,16 @@ def __init__(
             Each argument is expected to be a list of :class:`Variable`.
             Must return list of :class:`Variable`.
 
-            `NullType` instance: Treat as non-differentiable
-            `DisconnectedType` instance: Treat as disconnected gradient,
-            numerically gives zero
-
             ``list``:
                 Each `OpFromGraph`/callable must return a single
                :class:`Variable`. Each list element corresponds to the gradient
                of a specific input, length of list must be equal to number of
                inputs.
 
         rop_overrides
-            One of ``{'default', OpFromGraph, callable, Variable}``.
+            One of ``{None, OpFromGraph, callable, list}``.
 
-            Defaults to ``'default'``.
+            Defaults to ``None``.
 
-            ``'default'``: Do not override, use the default :meth:`Op.R_op` result
+            ``None``: Do not override, use the default :meth:`Op.R_op` result
 
             `OpFromGraph`: Override with another `OpFromGraph`, should
             accept inputs as the same order and types of ``inputs`` and ``eval_points``
@@ -379,10 +323,6 @@ def __init__(
             Each argument is expected to be a list of :class:`Variable`.
             Must return list of :class:`Variable`.
 
-            `NullType` instance: Treat as non-differentiable `DisconnectedType`
-            instance: Treat as zero since `DisconnectedType` is not yet supported
-            in :meth:`Op.R_op`.
-
             ``list``:
                 Each :class:`OpFromGraph`/callable must return a single
                 :class:`Variable `. Each list element
@@ -390,12 +330,15 @@ def __init__(
             must be equal to number of outputs.
         connection_pattern
             If not ``None``, this will be used as the connection_pattern for this
            :class:`Op`.
+
         strict: bool, default False
             If true, it raises when any variables needed to compute the inner graph
             are not provided as explicit inputs. This can only happen for graphs with
             shared variables.
+
         name
             A name for debugging purposes.
+
         kwargs
             Check :func:`pytensor.function` for more arguments, only works when not
             inline.
@@ -438,26 +381,36 @@ def __init__(
         self.input_types = [inp.type for inp in inputs]
         self.output_types = [out.type for out in outputs]
 
+        for override in (lop_overrides, grad_overrides, rop_overrides):
+            if override == "default":
+                raise ValueError(
+                    "'default' is no longer a valid value for overrides. Use None instead."
+                )
+            if isinstance(override, Variable):
+                raise TypeError(
+                    "Variables are no longer valid types for overrides. Return them in a list for each output instead"
+                )
+
         self.lop_overrides = lop_overrides
         self.grad_overrides = grad_overrides
         self.rop_overrides = rop_overrides
 
-        if lop_overrides != "default":
-            if grad_overrides != "default":
+        self._lop_op_interface = True
+        if grad_overrides is not None:
+            if lop_overrides is not None:
                 raise ValueError(
                     "lop_overrides and grad_overrides are mutually exclusive"
                 )
-            else:
-                self.set_lop_overrides(lop_overrides)
-                self._lop_type = "lop"
-        elif grad_overrides != "default":
-            self.set_lop_overrides(grad_overrides)
-            self._lop_type = "grad"
-        else:
-            self.set_lop_overrides("default")
-            self._lop_type = "lop"
-
-        self.set_rop_overrides(rop_overrides)
+            warnings.warn(
+                "grad_overrides is deprecated in favor of lop_overrides."
+                " Using it will lead to an error in the future.",
+                FutureWarning,
+            )
+            self._lop_op_interface = False
+        # Dictionary where we cache OpFromGraphs that represent the L_op
+        # A distinct OpFromGraph is needed to represent each pattern of output_grads connection
+        # It also returns a tuple that indicates which input_gradients are disconnected
+        self._lop_op_cache: dict[tuple[bool, ...], Callable] = {}
+        self._rop_op_cache: Callable | None = None
 
         self._connection_pattern = connection_pattern
 
@@ -478,327 +431,274 @@ def __str__(self):
         is_inline = self.is_inline
         return "{name}{{inline={is_inline}}}".format(**locals())
 
+    def _combine_list_overrides(self, default_outs, custom_outs, callable_args):
+        """Combines default and custom overrides into a single list of outputs."""
+        default_out_iter = iter(default_outs)
+        combined_outs = []
+        for custom_out in custom_outs:
+            if custom_out is None:
+                combined_outs.append(next(default_out_iter))
+            elif isinstance(custom_out, Variable):
+                if not isinstance(custom_out.type, NullType | DisconnectedType):
+                    raise ValueError(
+                        f"Override list can only contain NullType or DisconnectedType Variable instances, got {custom_out.type}"
+                    )
+                combined_outs.append(custom_out)
+            elif callable(custom_out):
+                combined_outs.append(custom_out(*callable_args))
+            else:
+                raise ValueError(
+                    f"Override list should contain None, Variable or callable, got {type(custom_out)}"
+                )
+        return combined_outs
+
+    def _call_custom_override(self, op_overrides, callable_args, nout):
+        """Calls custom override function and provides informative error messages."""
+        if not callable(op_overrides):
+            raise TypeError(
+                f"L_op/R_op override should be None, a list or a Callable, got {type(op_overrides)}"
+            )
+        outputs = op_overrides(*callable_args)
+        if not isinstance(outputs, list):
+            raise TypeError(
+                f"Lop/Rop overriding function should return a list, got {type(outputs)}"
+            )
+        if len(outputs) != nout:
+            raise ValueError(
+                f"Lop/Rop overriding function {self.rop_overrides} should return "
+                f"a list of {nout} outputs, got {len(outputs)}"
+            )
+        return outputs
+
     @config.change_flags(compute_test_value="off")
-    def _recompute_lop_op(self):
-        """
-        converts self._lop_op from user supplied form to type(self) instance
+    def _build_and_cache_lop_op(
+        self, disconnected_output_grads: tuple[bool, ...]
+    ) -> Callable:
+        """converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance,
+        specialized for the pattern of disconnected_output_grads
+        Results are cached in self._lop_op_cache
         """
-        local_inputs = self.inner_inputs
-        local_outputs = self.inner_outputs
-        inp_len = len(local_inputs)
-        lop_op = self._lop_op
-
-        if isinstance(lop_op, OpFromGraph):
-            if self._lop_op_is_cached:
-                return
-            assert self._lop_type in ("lop", "grad"), (
-                self.LOP_TYPE_ERR_MSG % self._lop_type
+        try:
+            return self._lop_op_cache[disconnected_output_grads]
+        except KeyError:
+            pass
+
+        inner_inputs = self.inner_inputs
+        inner_outputs = self.inner_outputs
+        nin = len(inner_inputs)
+        nout = len(inner_outputs)
+        lop_overrides = (
+            self.lop_overrides if self._lop_op_interface else self.grad_overrides
+        )
+
+        if isinstance(lop_overrides, OpFromGraph):
+            if self._lop_op_interface:
+                self._lop_op_cache[disconnected_output_grads] = lop_overrides
+                lop_overrides.kwargs["on_unused_input"] = "ignore"
+                return lop_overrides
+
+            else:
+                # We need to add a wrapper for the different input signature
+                # TODO: Remove this once the grad interface is gone
+                def lop_overrides(inps, grads):
+                    return self.grad_overrides(*inps, *grads)
+
+        # We try to compute the gradient with respect to connected outputs only
+        connected_inner_outputs = [
+            # We add an identity operation(copy) so that we don't override indirect
+            # gradient contributions to an inner output coming from other inner outputs
+            inner_out.copy()
+            for inner_out, disconnected in zip(
+                inner_outputs, disconnected_output_grads, strict=True
+            )
+            if not disconnected
+        ]
+        connected_output_grads = [
+            out_t()
+            for out_t, disconnected in zip(
+                self.output_types, disconnected_output_grads, strict=True
             )
-            if self._lop_type == "grad":
-                needed_ninps = inp_len + len(local_outputs)
-                ninps = len(lop_op.inner_inputs)
-                if needed_ninps != ninps:
-                    raise ValueError(self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
-                # make a wrapper callable
-
-                def lop_op(inps, grads):
-                    return self._lop_op(*(inps + grads))
-
-            elif self._lop_type == "lop":
-                # OfG can be directly used in L_op format
-                needed_ninps = inp_len + 2 * len(local_outputs)
-                ninps = len(lop_op.inner_inputs)
-                if needed_ninps != ninps:
-                    raise ValueError(self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
-            self._lop_op_is_cached = True
-            self._lop_op_stypes_l = [None] * inp_len
-            self._lop_op.kwargs["on_unused_input"] = "ignore"
-            return
-
-        output_grads = [out_t() for out_t in self.output_types]
+            if not disconnected
+        ]
         fn_grad = partial(
             grad,
             cost=None,
             disconnected_inputs="ignore",
-            return_disconnected="Disconnected",
+            return_disconnected="disconnected",
             null_gradients="return",
-            known_grads=OrderedDict(zip(local_outputs, output_grads)),
+            known_grads=dict(
+                zip(connected_inner_outputs, connected_output_grads, strict=True)
+            ),
         )
 
-        assert self._lop_type in ("lop", "grad"), self.LOP_TYPE_ERR_MSG % self._lop_type
-        if self._lop_type == "lop":
-            callable_args = (local_inputs, local_outputs, output_grads)
-        elif self._lop_type == "grad":
-            callable_args = (local_inputs, output_grads)
+        if self._lop_op_interface:
+            callable_args = (
+                inner_inputs,
+                connected_inner_outputs,
+                connected_output_grads,
+            )
+        else:
+            callable_args = (inner_inputs, connected_output_grads)
 
         # we need to convert _lop_op into an OfG instance
-        if lop_op == "default":
-            gdefaults_l = fn_grad(wrt=local_inputs)
-            all_grads_l, all_grads_ov_l = zip(
-                *[
-                    OpFromGraph._filter_grad_var(grad, inp)
-                    for grad, inp in zip(gdefaults_l, local_inputs)
-                ]
-            )
-            all_grads_l = list(all_grads_l)
-            all_grads_ov_l = list(all_grads_ov_l)
-        elif isinstance(lop_op, Variable):
-            if isinstance(lop_op.type, DisconnectedType | NullType):
-                all_grads_l = [inp.zeros_like() for inp in local_inputs]
-                all_grads_ov_l = [lop_op.type() for _ in range(inp_len)]
-            else:
-                raise ValueError(self.STYPE_ERR_MSG % lop_op.type)
-        elif isinstance(lop_op, list):
-            goverrides_l = lop_op
-            if len(goverrides_l) != inp_len:
+        if lop_overrides is None:
+            input_grads = fn_grad(wrt=inner_inputs)
+        elif isinstance(lop_overrides, list):
+            custom_input_grads = lop_overrides
+            if len(custom_input_grads) != nin:
                 raise ValueError(
-                    f"Need to override {int(inp_len)} gradients, got {len(goverrides_l)}",
-                    goverrides_l,
+                    f"Need to override {nin} gradients, got {len(custom_input_grads)}",
+                    custom_input_grads,
                 )
             # compute non-overriding downstream grads from upstream grads
             # it's normal some input may be disconnected, thus the 'ignore'
-            wrt_l = [
-                lin for lin, gov in zip(local_inputs, goverrides_l) if gov == "default"
+            wrt = [
+                lin for lin, gov in zip(inner_inputs, custom_input_grads) if gov is None
             ]
-            gdefaults = iter(fn_grad(wrt=wrt_l) if wrt_l else [])
-            # combine overriding gradients
-            all_grads_l = []
-            all_grads_ov_l = []
-            for inp, fn_gov in zip(local_inputs, goverrides_l):
-                if fn_gov == "default":
-                    gnext, gnext_ov = OpFromGraph._filter_grad_var(next(gdefaults), inp)
-                    all_grads_l.append(gnext)
-                    all_grads_ov_l.append(gnext_ov)
-                elif isinstance(fn_gov, Variable):
-                    if isinstance(fn_gov.type, DisconnectedType | NullType):
-                        all_grads_l.append(inp.zeros_like())
-                        all_grads_ov_l.append(fn_gov.type())
-                    else:
-                        raise ValueError(self.STYPE_ERR_MSG % fn_gov.type)
-                else:
-                    if not callable(fn_gov):
-                        raise TypeError(self.TYPE_ERR_MSG % fn_gov)
-                    gov, gov_ov = OpFromGraph._filter_grad_var(
-                        fn_gov(*callable_args), inp
-                    )
-                    all_grads_l.append(gov)
-                    all_grads_ov_l.append(gov_ov)
-        else:
-            # callable case
-            if not callable(lop_op):
-                raise TypeError(self.TYPE_ERR_MSG % lop_op)
-            goverrides_l = lop_op(*callable_args)
-            if not isinstance(goverrides_l, list):
-                raise TypeError(
-                    "Gradient/L_op overriding function should return a list, "
-                    f'got "{type(goverrides_l)}"'
-                )
-            all_grads_l, all_grads_ov_l = zip(
-                *[
-                    OpFromGraph._filter_grad_var(grad, inp)
-                    for grad, inp in zip(goverrides_l, local_inputs)
-                ]
+            default_input_grads = fn_grad(wrt=wrt) if wrt else []
+            input_grads = self._combine_list_overrides(
+                default_input_grads, custom_input_grads, callable_args
            )
-            if len(all_grads_l) != len(local_inputs):
-                raise ValueError(
-                    "Gradient/L_op overriding function should return list of "
-                    f"{int(inp_len)} outputs, got {len(all_grads_l)}"
-                )
-            all_grads_l = list(all_grads_l)
-            all_grads_ov_l = list(all_grads_ov_l)
-        self._lop_op = type(self)(
-            inputs=local_inputs + local_outputs + output_grads,
-            outputs=all_grads_l,
+        else:
+            input_grads = self._call_custom_override(lop_overrides, callable_args, nin)
+
+        # Filter out disconnected/null input grads generated from the inner graph grad
+        # We append them in the outer wrapper function below
+        connected_input_grads = [
+            inp_grad
+            for inp_grad in input_grads
+            if not isinstance(inp_grad.type, DisconnectedType | NullType)
+        ]
+        lop_op = type(self)(
+            inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
+            outputs=connected_input_grads,
             inline=self.is_inline,
-            name=(None if self.name is None else self.name + "_" + self._lop_type),
+            name=(None if self.name is None else f"{self.name}_LOp"),
f"{self.name}_LOp"), + # TODO: We can be eager here and exclude unused inputs in the OFG on_unused_input="ignore", ) - self._lop_op_stypes_l = all_grads_ov_l - self._lop_op_is_cached = True - self._lop_type = "lop" + + # Return a wrapper that combines connected and disconnected/null input gradients + # And also filters out disconnected/null output gradients + def wrapper(*inputs: Variable, **kwargs) -> list[Variable]: + inputs, outputs, output_grads = ( + inputs[: -nout * 2], + inputs[-nout * 2 : -nout], + inputs[-nout:], + ) + connected_outputs = [ + output + for output, output_grad in zip(outputs, output_grads, strict=True) + if not isinstance(output_grad.type, DisconnectedType | NullType) + ] + connected_output_grads = [ + output_grad + for output_grad in output_grads + if not isinstance(output_grad.type, DisconnectedType) + ] + connected_input_grads = iter( + lop_op(*inputs, *connected_outputs, *connected_output_grads, **kwargs) + ) + return [ + input_grad + if isinstance(input_grad.type, DisconnectedType | NullType) + else next(connected_input_grads) + for input_grad in input_grads + ] + + self._lop_op_cache[disconnected_output_grads] = wrapper + return wrapper @config.change_flags(compute_test_value="off") - def _recompute_rop_op(self): - """ - converts self._rop_op from user supplied form to type(self) instance + def _build_and_cache_rop_op(self): + """Converts rop_overrides from user supplied form to type(self) instance. + Results are cached in self._rop_op_cache """ - local_inputs = self.inner_inputs - local_outputs = self.inner_outputs - out_len = len(local_outputs) - rop_op = self._rop_op - - if isinstance(rop_op, OpFromGraph): - if not self._rop_op_is_cached: - self._rop_op_is_cached = True - self._rop_op_stypes_l = [None] * out_len - return + if self._rop_op_cache is not None: + return self._rop_op_cache + + inner_inputs = self.inner_inputs + inner_outputs = self.inner_outputs + nout = len(inner_outputs) + rop_overrides = self.rop_overrides + + if isinstance(rop_overrides, OpFromGraph): + self._rop_op_cache = rop_overrides + return rop_overrides eval_points = [inp_t() for inp_t in self.input_types] - fn_rop = partial(Rop, wrt=local_inputs, eval_points=eval_points) - TYPE_ERR_MSG = ( - "R_op overrides should be (single or list of)" - "OpFromGraph | 'default' | None | 0 | callable, got %s" - ) - STYPE_ERR_MSG = ( - "Overriding Variable instance can only have type" - " of DisconnectedType or NullType, got %s" - ) - if rop_op == "default": - rdefaults_l = fn_rop(f=local_outputs) - all_rops_l, all_rops_ov_l = zip( - *[ - OpFromGraph._filter_rop_var(rop, out) - for rop, out in zip(rdefaults_l, local_outputs) - ] - ) - all_rops_l = list(all_rops_l) - all_rops_ov_l = list(all_rops_ov_l) - elif isinstance(rop_op, Variable): - if isinstance(rop_op.type, NullType): - all_rops_l = [inp.zeros_like() for inp in local_inputs] - all_rops_ov_l = [rop_op.type() for _ in range(out_len)] - elif isinstance(rop_op.type, DisconnectedType): - all_rops_l = [inp.zeros_like() for inp in local_inputs] - all_rops_ov_l = [None] * out_len - else: - raise ValueError(STYPE_ERR_MSG % rop_op.type) - elif isinstance(rop_op, list): - roverrides_l = rop_op - if len(roverrides_l) != out_len: + fn_rop = partial(Rop, wrt=inner_inputs, eval_points=eval_points) + + callable_args = (inner_inputs, eval_points) + if rop_overrides is None: + output_grads = fn_rop(f=inner_outputs) + elif isinstance(rop_overrides, list): + custom_output_grads = rop_overrides + if len(custom_output_grads) != nout: raise ValueError( - f"Need 
-                    f"Need to override {int(out_len)} Rop, got {len(roverrides_l)}",
-                    roverrides_l,
+                    f"Need to override {int(nout)} Rop, got {len(custom_output_grads)}",
+                    custom_output_grads,
                 )
             # get outputs that do not have an Rop override
-            odefaults_l = [
-                lo for lo, rov in zip(local_outputs, roverrides_l) if rov == "default"
+            f = [
+                output
+                for output, custom_output_grad in zip(
+                    inner_outputs, custom_output_grads
+                )
+                if custom_output_grad is None
             ]
-            rdefaults_l = fn_rop(f=odefaults_l)
-            rdefaults = iter(rdefaults_l if odefaults_l else [])
-            # combine overriding Rops
-            all_rops_l = []
-            all_rops_ov_l = []
-            for out, fn_rov in zip(local_outputs, roverrides_l):
-                if fn_rov == "default":
-                    rnext, rnext_ov = OpFromGraph._filter_rop_var(next(rdefaults), out)
-                    all_rops_l.append(rnext)
-                    all_rops_ov_l.append(rnext_ov)
-                elif isinstance(fn_rov, Variable):
-                    if isinstance(fn_rov.type, NullType):
-                        all_rops_l.append(out.zeros_like())
-                        all_rops_ov_l.append(fn_rov.type())
-                    if isinstance(fn_rov.type, DisconnectedType):
-                        all_rops_l.append(out.zeros_like())
-                        all_rops_ov_l.append(None)
-                    else:
-                        raise ValueError(STYPE_ERR_MSG % fn_rov.type)
-                else:
-                    if not callable(fn_rov):
-                        raise TypeError(TYPE_ERR_MSG % fn_rov)
-                    rov, rov_ov = OpFromGraph._filter_rop_var(
-                        fn_rov(local_inputs, eval_points), out
-                    )
-                    all_rops_l.append(rov)
-                    all_rops_ov_l.append(rov_ov)
+            default_output_grads = fn_rop(f=f) if f else []
+            output_grads = self._combine_list_overrides(
+                default_output_grads, custom_output_grads, callable_args
+            )
         else:
-            if not callable(rop_op):
-                raise TypeError(TYPE_ERR_MSG % rop_op)
-            roverrides_l = rop_op(local_inputs, eval_points)
-            if not isinstance(roverrides_l, list):
-                raise TypeError(
-                    "Rop overriding function should return a list, "
-                    f'got "{type(roverrides_l)}"'
-                )
-            all_rops_l, all_rops_ov_l = zip(
-                *[
-                    OpFromGraph._filter_rop_var(rop, out)
-                    for rop, out in zip(roverrides_l, local_outputs)
-                ]
+            output_grads = self._call_custom_override(
+                rop_overrides, callable_args, nout
             )
-            if len(all_rops_l) != out_len:
-                raise ValueError(
-                    (
-                        f"Rop overriding function {self._rop_op} should return list of "
-                        f"{int(out_len)} outputs, got {len(all_rops_l)}",
-                    ),
-                    rop_op,
-                )
-            all_rops_l = list(all_rops_l)
-            all_rops_ov_l = list(all_rops_ov_l)
-        self._rop_op = type(self)(
-            inputs=local_inputs + eval_points,
-            outputs=all_rops_l,
+
+        # Filter out disconnected output gradients
+        filtered_output_grads = [
+            out_grad
+            for out_grad in output_grads
+            if not isinstance(out_grad.type, DisconnectedType | NullType)
+        ]
+        rop_op = type(self)(
+            inputs=inner_inputs + eval_points,
+            outputs=filtered_output_grads,
             inline=self.is_inline,
             name=(None if self.name is None else self.name + "_rop"),
             on_unused_input="ignore",
         )
-        self._rop_op_stypes_l = all_rops_ov_l
-        self._rop_op_is_cached = True
-
-    def get_lop_op(self):
-        if not self._lop_op_is_cached:
-            self._recompute_lop_op()
-        return self._lop_op
-
-    def get_rop_op(self):
-        if not self._rop_op_is_cached:
-            self._recompute_rop_op()
-        return self._rop_op
-
-    def set_grad_overrides(self, grad_overrides):
-        """
-        Set gradient overrides.
-        This will completely remove any previously set L_op/gradient overrides
-
-        """
-        self._lop_op = grad_overrides
-        self._lop_op_is_cached = False
-        self._lop_type = "grad"
-        self._lop_is_default = grad_overrides == "default"
-
-    def set_lop_overrides(self, lop_overrides):
-        """
-        Set L_op overrides
-        This will completely remove any previously set L_op/gradient overrides
-        """
-        self._lop_op = lop_overrides
-        self._lop_op_is_cached = False
-        self._lop_type = "lop"
-        self._lop_is_default = lop_overrides == "default"
-
-    def set_rop_overrides(self, rop_overrides):
-        """
-        Set R_op overrides
-        This will completely remove any previously set R_op overrides
+        # Return a wrapper that combines connected and disconnected output gradients
+        def wrapper(*inputs: Variable, **kwargs) -> list[Variable | None]:
+            connected_output_grads = iter(rop_op(*inputs, **kwargs))
+            all_output_grads = []
+            for out_grad in output_grads:
+                if isinstance(out_grad.type, DisconnectedType):
+                    # R_Op does not have DisconnectedType yet, None should be used instead
+                    all_output_grads.append(None)
+                elif isinstance(out_grad.type, NullType):
+                    all_output_grads.append(out_grad)
+                else:
+                    all_output_grads.append(next(connected_output_grads))
+            return all_output_grads
 
-        """
-        self._rop_op = rop_overrides
-        self._rop_op_is_cached = False
-        self._rop_is_default = rop_overrides == "default"
+        self._rop_op_cache = wrapper
+        return wrapper
 
     def L_op(self, inputs, outputs, output_grads):
-        if not self._lop_op_is_cached:
-            self._recompute_lop_op()
-        inps = list(inputs) + list(outputs) + list(output_grads)
-        ret_ofg_l = self._lop_op(*inps, return_list=True)
-        ret_l = [
-            ret_ofg if ov is None else ov
-            for ret_ofg, ov in zip(ret_ofg_l, self._lop_op_stypes_l)
-        ]
-        return ret_l
+        disconnected_output_grads = tuple(
+            isinstance(og.type, DisconnectedType) for og in output_grads
+        )
+        lop_op = self._build_and_cache_lop_op(disconnected_output_grads)
+        return lop_op(*inputs, *outputs, *output_grads, return_list=True)
 
     def R_op(self, inputs, eval_points):
-        if not self._rop_op_is_cached:
-            self._recompute_rop_op()
-        ret_ofg_l = self._rop_op(*(list(inputs) + list(eval_points)), return_list=True)
-        ret_l = [
-            ret_ofg if ov is None else ov
-            for ret_ofg, ov in zip(ret_ofg_l, self._rop_op_stypes_l)
-        ]
-        return ret_l
+        rop_op = self._build_and_cache_rop_op()
+        return rop_op(*inputs, *eval_points, return_list=True)
 
     def __call__(self, *inputs, **kwargs):
         # The user interface doesn't expect the shared variable inputs of the
@@ -886,30 +786,9 @@ def connection_pattern(self, node):
         if self._connection_pattern is not None:
             return self._connection_pattern
 
-        inp_len = len(self.inner_inputs)
-        out_len = len(self.inner_outputs)
-
-        cpmat_self = io_connection_pattern(self.inner_inputs, self.inner_outputs)
-
-        lop_op = self.get_lop_op()
-        cpmat_grad = io_connection_pattern(
-            lop_op.inner_inputs[inp_len:], lop_op.inner_outputs
-        )
-
-        # cpmat_self |= cpmat_grad.T
-        # cpmat_self &= out_is_disconnected
-        for i, t in enumerate(self._lop_op_stypes_l):
-            if t is not None:
-                if isinstance(t.type, DisconnectedType):
-                    for o in range(out_len):
-                        cpmat_self[i][o] = False
-            for o in range(out_len):
-                cpmat_self[i][o] |= cpmat_grad[o][i]
-
-        # TODO in case DisconnectedType is implemented for R_op,
-        # self._rop_op_stypes_l self._rop_op should considered for
-        # connection_pattern
-
-        return list(map(list, cpmat_self))
+        ret = io_connection_pattern(self.inner_inputs, self.inner_outputs)
+        self._connection_pattern = ret
+        return ret
 
     def infer_shape(self, fgraph, node, shapes):
         # TODO: Use `fgraph.shape_feature` to do this instead.
diff --git a/pytensor/gradient.py b/pytensor/gradient.py
index 78862de7e1..c8a896a185 100644
--- a/pytensor/gradient.py
+++ b/pytensor/gradient.py
@@ -929,12 +929,7 @@ def account_for(var):
                 continue
 
             if ipt not in var_to_app_to_idx:
-                # This object here *must* be ordered, because
-                # we iterate over its keys when adding up the terms of the
-                # gradient on ipt. If it is a regular dict, the grad method
-                # will return something that is analytically correct, but
-                # whose order of doing additions depends on the memory
-                # location of the apply nodes.
+                # This object *must* be ordered for the grad graph to be deterministic
                 var_to_app_to_idx[ipt] = {}
             app_to_idx = var_to_app_to_idx[ipt]
             if app not in app_to_idx:
diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py
index 6f8e8035d1..d71094bfed 100644
--- a/tests/compile/test_builders.py
+++ b/tests/compile/test_builders.py
@@ -8,10 +8,16 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
 from pytensor.configdefaults import config
-from pytensor.gradient import DisconnectedType, Rop, disconnected_type, grad
+from pytensor.gradient import (
+    DisconnectedType,
+    Rop,
+    disconnected_type,
+    grad,
+    verify_grad,
+)
 from pytensor.graph.basic import equal_computations
 from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.null_type import NullType
+from pytensor.graph.null_type import NullType, null_type
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.graph.utils import MissingInputError
 from pytensor.printing import debugprint
@@ -22,7 +28,15 @@
 from pytensor.tensor.random.utils import RandomStream
 from pytensor.tensor.rewriting.shape import ShapeOptimizer
 from pytensor.tensor.shape import specify_shape
-from pytensor.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors
+from pytensor.tensor.type import (
+    TensorType,
+    dscalars,
+    matrices,
+    matrix,
+    scalar,
+    vector,
+    vectors,
+)
 
 from tests import unittest_tools
 from tests.graph.utils import MyVariable
@@ -93,6 +107,20 @@ def test_size_changes(self, cls_ofg):
         assert res.shape == (2, 5)
         assert np.all(180.0 == res)
 
+    def test_overrides_deprecated_api(self):
+        inp = scalar("x")
+        out = inp + 1
+        for kwarg in ("lop_overrides", "grad_overrides", "rop_overrides"):
+            with pytest.raises(
+                ValueError, match="'default' is no longer a valid value for overrides"
+            ):
+                OpFromGraph([inp], [out], **{kwarg: "default"})
+
+            with pytest.raises(
+                TypeError, match="Variables are no longer valid types for overrides"
+            ):
+                OpFromGraph([inp], [out], **{kwarg: null_type()})
+
     @pytest.mark.parametrize(
         "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
     )
@@ -181,8 +209,9 @@ def go(inps, gs):
         dedz = vector("dedz")
         op_mul_grad = cls_ofg([x, y, dedz], go([x, y], [dedz]))
 
-        op_mul = cls_ofg([x, y], [x * y], grad_overrides=go)
-        op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad)
+        with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
+            op_mul = cls_ofg([x, y], [x * y], grad_overrides=go)
+            op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad)
 
         # single override case (function or OfG instance)
         xx, yy = vector("xx"), vector("yy")
@@ -209,9 +238,8 @@ def go2(inps, gs):
         w, b = vectors("wb")
 
         # we make the 3rd gradient default (no override)
-        op_linear = cls_ofg(
-            [x, w, b], [x * w + b], grad_overrides=[go1, go2, "default"]
-        )
+        with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
+            op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, None])
 
         xx, ww, bb = vector("xx"), vector("yy"), vector("bb")
         zz = pt_sum(op_linear(xx, ww, bb))
         dx, dw, db = grad(zz, [xx, ww, bb])
@@ -225,11 +253,14 @@ def go2(inps, gs):
         np.testing.assert_array_almost_equal(np.ones(16, dtype=config.floatX), dbv, 4)
 
         # NullType and DisconnectedType
-        op_linear2 = cls_ofg(
-            [x, w, b],
-            [x * w + b],
-            grad_overrides=[go1, NullType()(), DisconnectedType()()],
-        )
+        with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
+            op_linear2 = cls_ofg(
+                [x, w, b],
+                [x * w + b],
+                grad_overrides=[go1, NullType()(), DisconnectedType()()],
+                # This is a fake override, so a fake connection_pattern must be provided as well
+                connection_pattern=[[True], [True], [False]],
+            )
         zz2 = pt_sum(op_linear2(xx, ww, bb))
         dx2, dw2, db2 = grad(
             zz2,
@@ -293,6 +324,41 @@ def test_rop(self, cls_ofg):
         dvval2 = fn(xval, Wval, duval)
         np.testing.assert_array_almost_equal(dvval2, dvval, 4)
 
+    def test_rop_multiple_outputs(self):
+        a = vector()
+        M = matrix()
+        b = dot(a, M)
+        op_matmul = OpFromGraph([a, M], [b, -b])
+
+        x = vector()
+        W = matrix()
+        du = vector()
+
+        xval = np.random.random((16,)).astype(config.floatX)
+        Wval = np.random.random((16, 16)).astype(config.floatX)
+        duval = np.random.random((16,)).astype(config.floatX)
+
+        y = op_matmul(x, W)[0]
+        dv = Rop(y, x, du)
+        fn = function([x, W, du], dv)
+        result_dvval = fn(xval, Wval, duval)
+        expected_dvval = np.dot(duval, Wval)
+        np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
+
+        y = op_matmul(x, W)[1]
+        dv = Rop(y, x, du)
+        fn = function([x, W, du], dv)
+        result_dvval = fn(xval, Wval, duval)
+        expected_dvval = -np.dot(duval, Wval)
+        np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
+
+        y = pt.add(*op_matmul(x, W))
+        dv = Rop(y, x, du)
+        fn = function([x, W, du], dv)
+        result_dvval = fn(xval, Wval, duval)
+        expected_dvval = np.zeros_like(np.dot(duval, Wval))
+        np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
+
     @pytest.mark.parametrize(
         "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
     )
@@ -339,13 +405,14 @@ def f1(x, y):
         def f1_back(inputs, output_gradients):
             return [output_gradients[0], disconnected_type()]
 
-        op = cls_ofg(
-            inputs=[x, y],
-            outputs=[f1(x, y)],
-            grad_overrides=f1_back,
-            connection_pattern=[[True], [False]],  # This is new
-            on_unused_input="ignore",
-        )  # This is new
+        with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
+            op = cls_ofg(
+                inputs=[x, y],
+                outputs=[f1(x, y)],
+                grad_overrides=f1_back,
+                connection_pattern=[[True], [False]],
+                on_unused_input="ignore",
+            )
 
         c = op(x, y)
 
@@ -585,6 +652,34 @@ def test_explicit_input_from_shared(self):
         out = test_ofg(y, y)
         assert out.eval() == 4
 
+    def test_L_op_disconnected_output_grad(self):
+        x, y = dscalars("x", "y")
+        rng = np.random.default_rng(594)
+        point = list(rng.normal(size=(2,)))
+
+        out1 = x + y
+        out2 = x * y
+        out3 = out1 * out2  # Create dependency between outputs
+        op = OpFromGraph([x, y], [out1, out2, out3])
+        verify_grad(lambda x, y: pt.add(*op(x, y)), point, rng=rng)
+        verify_grad(lambda x, y: pt.add(*op(x, y)[:-1]), point, rng=rng)
+        verify_grad(lambda x, y: pt.add(*op(x, y)[1:]), point, rng=rng)
+        verify_grad(lambda x, y: pt.add(*op(x, y)[::2]), point, rng=rng)
+        verify_grad(lambda x, y: op(x, y)[0], point, rng=rng)
+        verify_grad(lambda x, y: op(x, y)[1], point, rng=rng)
+        verify_grad(lambda x, y: op(x, y)[2], point, rng=rng)
+
+        # Test disconnected graphs are handled correctly
+        op = OpFromGraph([x, y], [x**2, y**3])
+        with pytest.warns(UserWarning):
+            grad_x_wrt_y = grad(
+                op(x, y)[0],
+                wrt=y,
+                return_disconnected="disconnected",
+                disconnected_inputs="warn",
+            )
+        assert isinstance(grad_x_wrt_y.type, DisconnectedType)
+
     def test_repeated_inputs(self):
         x = pt.dscalar("x")
         y = pt.dscalar("y")
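For reference, a minimal sketch of the override API that replaces the removed ``'default'``/Variable forms, following the ``lop_overrides`` docstring above; the helper name ``double_dx`` and the chosen graph are illustrative only, not part of the patch:

.. code-block:: python

    import pytensor.tensor as pt
    from pytensor import function
    from pytensor.compile.builders import OpFromGraph
    from pytensor.gradient import grad

    x, y = pt.scalars("x", "y")
    e = x * y

    # List-style override: one entry per input; None keeps the default L_op
    # result for that input. Each callable receives (inputs, outputs, output_grads).
    def double_dx(inputs, outputs, output_grads):
        x, y = inputs
        (g,) = output_grads
        return 2 * g * y  # twice the true gradient d(x*y)/dx, for illustration

    op = OpFromGraph([x, y], [e], lop_overrides=[double_dx, None])
    z = op(x, y)
    dx, dy = grad(z, [x, y])
    fn = function([x, y], [dx, dy])
    print(fn(3.0, 4.0))  # dx is doubled (8.0), dy keeps the default (3.0)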