38 changes: 17 additions & 21 deletions doc/extending/using_params.rst
@@ -1,10 +1,10 @@
.. _extending_op_params:

===============
Using Op params
===============
================
Using COp params
================

The Op params is a facility to pass some runtime parameters to the
The COp params is a facility to pass some runtime parameters to the
code of an op without modifying it. It can enable a single instance
of C code to serve different needs and therefore reduce the amount of
compilation needed.
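
For instance, compare the C code an op might generate with a baked-in
constant against the params-based version used later on this page (a
hypothetical sketch; ``%(p)s`` stands for the params variable provided
through `sub`):

.. code-block:: python

    # Hypothetical sketch of two c_code bodies for an op that multiplies
    # its input by ``self.mul``.

    # Baked-in constant: the generated C source differs per value, so every
    # distinct multiplier triggers a separate compilation.
    baked_in = "%(z)s = %(x)s * 2.5;"

    # Read from params: the source is identical for all instances; the
    # multiplier is fetched from the params object at runtime.
    via_params = "%(z)s = %(x)s * PyFloat_AsDouble(%(p)s);"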

@@ -53,7 +53,7 @@ following methods will be used for the type:
- :meth:`__hash__ <Type.__hash__>`
- :meth:`values_eq <Type.values_eq>`

Additionally if you want to use your params with C code, you need to extend `COp`
Additionally, to use your params with C code, you need to extend `COp`
and implement the following methods:

- :meth:`c_declare <CLinkerType.c_declare>`
@@ -65,24 +65,24 @@ You can also define other convenience methods such as
:meth:`c_headers <CLinkerType.c_headers>` if you need anything extra
(such as additional C headers).


Registering the params with your Op
-----------------------------------
Registering the params with your COp
------------------------------------

To declare that your Op uses params you have to set the class
To declare that your `COp` uses params you have to set the class
attribute :attr:`params_type` to an instance of your params Type.

.. note::

If you want to have multiple parameters, PyTensor provides the convenient class
:class:`pytensor.link.c.params_type.ParamsType` that allows bundling many parameters into
one object that will be available in both Python (as a Python object) and C code (as a struct).
one object that will be available to the C code (as a struct).
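
A minimal sketch of such a bundle (the class and field names here are
hypothetical; each `ParamsType` field must match an attribute of the
same name on the op):

.. code-block:: python

    from pytensor.link.c.op import COp
    from pytensor.link.c.params_type import ParamsType
    from pytensor.scalar import bool as bool_t, float64


    class MyScaledOp(COp):
        # Each field is backed by the op attribute of the same name and
        # becomes a member of the params struct in the generated C code.
        params_type = ParamsType(alpha=float64, inplace=bool_t)
        __props__ = ("alpha", "inplace")

        def __init__(self, alpha, inplace=False):
            self.alpha = alpha
            self.inplace = inplace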

For example, if we decide to use an int as the params, the following
would be appropriate:

.. code-block:: python

class MyOp(Op):
class MyOp(COp):
params_type = Generic()

After that you need to define a :meth:`get_params` method on your
@@ -115,12 +115,7 @@
Declaring params for your `COp` no longer affects the expected
signature of :meth:`perform`: the params object is not passed as an
extra argument, so read the underlying attributes from `self` (or call
:meth:`get_params`) instead.

.. warning::

If you do not account for this extra parameter, the code will fail
at runtime if it tries to run the python version.

Also, for the C code, the `sub` dictionary will contain an extra entry
The `sub` dictionary for `COp`s with params will contain an extra entry
`'params'` which will map to the variable name of the params object.
This is true for all methods that receive a `sub` parameter, so this
means that you can use your params in the :meth:`c_code <COp.c_code>`
@@ -131,7 +126,7 @@
A simple example
----------------

This is a simple example which uses a params object to pass a value.
This `Op` will multiply a scalar input by a fixed floating point value.
This `COp` will multiply a scalar input by a fixed floating point value.

Since the value in this case is a python float, we chose Generic as
the params type.
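
Pieced together from the hunks below, the full example reads roughly as
follows (a sketch: the `__props__`, `__init__`, the `c_code`
substitution dict, and `c_code_cache_version` are assumptions, since the
diff collapses those lines):

.. code-block:: python

    from pytensor.graph.basic import Apply
    from pytensor.link.c.op import COp
    from pytensor.link.c.type import Generic
    from pytensor.scalar import as_scalar


    class MyOp(COp):
        params_type = Generic()
        __props__ = ("mul",)

        def __init__(self, mul):
            self.mul = mul

        def get_params(self, node):
            # The object handed to the C code at runtime; a plain float here.
            return self.mul

        def make_node(self, inp):
            inp = as_scalar(inp)
            return Apply(self, [inp], [inp.type()])

        def perform(self, node, inputs, output_storage):
            # params is a plain Python float, so `self.mul` works directly.
            output_storage[0][0] = inputs[0] * self.mul

        def c_code(self, node, name, inputs, outputs, sub):
            # sub["params"] holds the C variable name of the params object.
            return "%(z)s = %(x)s * PyFloat_AsDouble(%(p)s);" % dict(
                z=outputs[0], x=inputs[0], p=sub["params"]
            )

        def c_code_cache_version(self):
            return (1,)

An instance would then be applied like any other scalar op, e.g.
``MyOp(3.0)(x)``.
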
@@ -156,9 +151,10 @@
inp = as_scalar(inp)
return Apply(self, [inp], [inp.type()])

def perform(self, node, inputs, output_storage, params):
# Here params is a python float so this is ok
output_storage[0][0] = inputs[0] * params
def perform(self, node, inputs, output_storage):
# Because params is a python float we can use `self.mul` directly.
# If it's something fancier, call `self.params_type.filter(self.get_params(node))`
output_storage[0][0] = inputs[0] * self.mul

def c_code(self, node, name, inputs, outputs, sub):
return ("%(z)s = %(x)s * PyFloat_AsDouble(%(p)s);" %
@@ -174,7 +170,7 @@ weights.

.. testcode::

from pytensor.graph.op import Op
from pytensor.link.c.op import COp
from pytensor.link.c.type import Generic
from pytensor.scalar import as_scalar

11 changes: 0 additions & 11 deletions pytensor/graph/basic.py
@@ -30,7 +30,6 @@
from pytensor.configdefaults import config
from pytensor.graph.utils import (
MetaObject,
MethodNotDefined,
Scratchpad,
TestValueError,
ValidatingScratchpad,
@@ -151,16 +150,6 @@ def __init__(
f"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}"
)

def run_params(self):
"""
Returns the params for the node, or NoParams if no params is set.

"""
try:
return self.op.get_params(self)
except MethodNotDefined:
return NoParams

def __getstate__(self):
d = self.__dict__
# ufunc don't pickle/unpickle well
65 changes: 9 additions & 56 deletions pytensor/graph/op.py
@@ -16,15 +16,13 @@

import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, NoParams, Variable
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.utils import (
MetaObject,
MethodNotDefined,
TestValueError,
add_tag_trace,
get_variable_trace_string,
)
from pytensor.link.c.params_type import Params, ParamsType


if TYPE_CHECKING:
@@ -37,10 +35,7 @@
ComputeMapType = dict[Variable, list[bool]]
InputStorageType = list[StorageCellType]
OutputStorageType = list[StorageCellType]
ParamsInputType = Optional[tuple[Any, ...]]
PerformMethodType = Callable[
[Apply, list[Any], OutputStorageType, ParamsInputType], None
]
PerformMethodType = Callable[[Apply, list[Any], OutputStorageType], None]
BasicThunkType = Callable[[], None]
ThunkCallableType = Callable[
[PerformMethodType, StorageMapType, ComputeMapType, Apply], None
@@ -202,7 +197,6 @@ class Op(MetaObject):

itypes: Optional[Sequence["Type"]] = None
otypes: Optional[Sequence["Type"]] = None
params_type: Optional[ParamsType] = None

_output_type_depends_on_input_value = False
"""
@@ -426,7 +420,6 @@ def perform(
node: Apply,
inputs: Sequence[Any],
output_storage: OutputStorageType,
params: ParamsInputType = None,
) -> None:
"""Calculate the function on the inputs and put the variables in the output storage.

@@ -442,8 +435,6 @@
these lists). Each sub-list corresponds to value of each
`Variable` in :attr:`node.outputs`. The primary purpose of this method
is to set the values of these sub-lists.
params
A tuple containing the values of each entry in :attr:`Op.__props__`.

Notes
-----
@@ -481,22 +472,6 @@ def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool:
"""
return True

def get_params(self, node: Apply) -> Params:
"""Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
if isinstance(self.params_type, ParamsType):
wrapper = self.params_type
if not all(hasattr(self, field) for field in wrapper.fields):
# Let's print missing attributes for debugging.
not_found = tuple(
field for field in wrapper.fields if not hasattr(self, field)
)
raise AttributeError(
f"{type(self).__name__}: missing attributes {not_found} for ParamsType."
)
# ParamsType.get_params() will apply filtering to attributes.
return self.params_type.get_params(self)
raise MethodNotDefined("get_params")

def prepare_node(
self,
node: Apply,
@@ -538,34 +513,12 @@ def make_py_thunk(
else:
p = node.op.perform

params = node.run_params()

if params is NoParams:
# default arguments are stored in the closure of `rval`
@is_thunk_type
def rval(
p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
return r

else:
params_val = node.params_type.filter(params)

@is_thunk_type
def rval(
p=p,
i=node_input_storage,
o=node_output_storage,
n=node,
params=params_val,
):
r = p(n, [x[0] for x in i], o, params)
for o in node.outputs:
compute_map[o][0] = True
return r
@is_thunk_type
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
return r

rval.inputs = node_input_storage
rval.outputs = node_output_storage
@@ -640,7 +593,7 @@ class _NoPythonOp(Op):

"""

def perform(self, node, inputs, output_storage, params=None):
def perform(self, node, inputs, output_storage):
raise NotImplementedError("No Python implementation is provided by this Op.")


13 changes: 11 additions & 2 deletions pytensor/link/c/basic.py
@@ -20,6 +20,7 @@
io_toposort,
vars_between,
)
from pytensor.graph.utils import MethodNotDefined
from pytensor.link.basic import Container, Linker, LocalLinker, PerformLinker
from pytensor.link.c.cmodule import (
METH_VARARGS,
@@ -617,7 +618,12 @@ def fetch_variables(self):
# that needs it
self.node_params = dict()
for node in self.node_order:
params = node.run_params()
if not isinstance(node.op, CLinkerOp):
continue
try:
params = node.op.get_params(node)
except MethodNotDefined:
params = NoParams
if params is not NoParams:
# try to avoid creating more than one variable for the
# same params.
@@ -803,7 +809,10 @@ def code_gen(self):

sub = dict(failure_var=failure_var)

params = node.run_params()
try:
ricardoV94 (Member, Author) commented on Nov 22, 2023:
We raise 4 lines above if it's not a `CLinkerOp`.

params = op.get_params(node)
except MethodNotDefined:
params = NoParams
if params is not NoParams:
params_var = symbol[self.node_params[params]]

25 changes: 24 additions & 1 deletion pytensor/link/c/interface.py
@@ -1,11 +1,16 @@
import typing
import warnings
from abc import abstractmethod
from typing import Callable
from typing import Callable, Optional

from pytensor.graph.basic import Apply, Constant
from pytensor.graph.utils import MethodNotDefined


if typing.TYPE_CHECKING:
from pytensor.link.c.params_type import Params, ParamsType


class CLinkerObject:
"""Standard methods for an `Op` or `Type` used with the `CLinker`."""

@@ -172,6 +177,8 @@ def c_code_cache_version(self) -> tuple[int, ...]:
class CLinkerOp(CLinkerObject):
"""Interface definition for `Op` subclasses compiled by `CLinker`."""

params_type: Optional["ParamsType"] = None

@abstractmethod
def c_code(
self,
@@ -362,6 +369,22 @@ def c_cleanup_code_struct(self, node: Apply, name: str) -> str:
"""
return ""

def get_params(self, node: Apply) -> "Params":
"""Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
if self.params_type is not None:
wrapper = self.params_type
if not all(hasattr(self, field) for field in wrapper.fields):
# Let's print missing attributes for debugging.
not_found = tuple(
field for field in wrapper.fields if not hasattr(self, field)
)
raise AttributeError(
f"{type(self).__name__}: missing attributes {not_found} for ParamsType."
)
# ParamsType.get_params() will apply filtering to attributes.
return self.params_type.get_params(self)
raise MethodNotDefined("get_params")


class CLinkerType(CLinkerObject):
r"""Interface specification for `Type`\s that can be arguments to a `CLinkerOp`.
4 changes: 2 additions & 2 deletions pytensor/link/c/op.py
@@ -664,7 +664,7 @@ class _NoPythonCOp(COp):

"""

def perform(self, node, inputs, output_storage, params=None):
def perform(self, node, inputs, output_storage):
raise NotImplementedError("No Python implementation is provided by this COp.")


@@ -675,7 +675,7 @@ class _NoPythonExternalCOp(ExternalCOp):

"""

def perform(self, node, inputs, output_storage, params=None):
def perform(self, node, inputs, output_storage):
raise NotImplementedError(
"No Python implementation is provided by this ExternalCOp."
)
21 changes: 5 additions & 16 deletions pytensor/link/numba/dispatch/basic.py
@@ -21,7 +21,7 @@
from pytensor import config
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Apply, NoParams
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
from pytensor.ifelse import IfElse
@@ -383,22 +383,11 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
ret_sig = get_numba_type(node.outputs[0].type)

output_types = tuple(out.type for out in node.outputs)
params = node.run_params()

if params is not NoParams:
params_val = dict(node.params_type.filter(params))

def py_perform(inputs):
outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs, params_val)
return outputs

else:

def py_perform(inputs):
outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs)
return outputs
def py_perform(inputs):
outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs)
return outputs

if n_outputs == 1:

2 changes: 1 addition & 1 deletion pytensor/raise_op.py
@@ -90,7 +90,7 @@ def make_node(self, value: Variable, *conds: Variable):
[value.type()],
)

def perform(self, node, inputs, outputs, params):
def perform(self, node, inputs, outputs):
(out,) = outputs
val, *conds = inputs
out[0] = val