BUG: Default rewrites applied to scans with sit-sot inputs regardless of backend #426

@jessegrabowski

Description

Describe the issue:

When compiling a function to a non-default backend (JAX, numba), the default rewrites are still applied to the inner function under specific conditions. This can result in invalid graphs when those rewrites introduce Ops the backend has no conversion for, for example when dot is rewritten to Dot22. This appears to happen only when 1) there are sit-sot inputs to the scan, and 2) the length of the scan is greater than 1.

I had previously mentioned this to @ricardoV94 because I thought it was a simple case of the mode not propagating to the inner function, but the bug seems a bit more subtle than that.

Reproducible code example:

import pytensor
import pytensor.tensor as pt
import numpy as np

A = pt.matrix('A')
B = pt.matrix('B')

# Works, no error
out, _ = pytensor.scan(lambda a, b: a @ b, non_sequences=[A, B], n_steps=1)
f = pytensor.function([A, B], out, mode='JAX')
f(np.eye(3), np.eye(3))

# Works, no error
out, _ = pytensor.scan(lambda a, b: a @ b, non_sequences=[A, B], n_steps=2)
f = pytensor.function([A, B], out, mode='JAX')
print(f(np.eye(3), np.eye(3)))

# Works, no error
out, _ = pytensor.scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=1)
f = pytensor.function([A, B], out, mode='JAX')
f(np.eye(3), np.eye(3))

# Fails
out, _ = pytensor.scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
f = pytensor.function([A, B], out, mode='JAX')
f(np.eye(3), np.eye(3))

Error message:

NotImplementedError                       Traceback (most recent call last)
Cell In[36], line 19
     17 # Fails
     18 out, _ = pytensor.scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
---> 19 f = pytensor.function([A, B], out, mode='JAX')
     20 f(np.eye(3), np.eye(3))

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/compile/function/__init__.py:315, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
    309     fn = orig_function(
    310         inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
    311     )
    312 else:
    313     # note: pfunc will also call orig_function -- orig_function is
    314     #      a choke point that all compilation must pass through
--> 315     fn = pfunc(
    316         params=inputs,
    317         outputs=outputs,
    318         mode=mode,
    319         updates=updates,
    320         givens=givens,
    321         no_default_updates=no_default_updates,
    322         accept_inplace=accept_inplace,
    323         name=name,
    324         rebuild_strict=rebuild_strict,
    325         allow_input_downcast=allow_input_downcast,
    326         on_unused_input=on_unused_input,
    327         profile=profile,
    328         output_keys=output_keys,
    329     )
    330 return fn

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:468, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
    454     profile = ProfileStats(message=profile)
    456 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    457     params,
    458     outputs,
   (...)
    465     fgraph=fgraph,
    466 )
--> 468 return orig_function(
    469     inputs,
    470     cloned_outputs,
    471     mode,
    472     accept_inplace=accept_inplace,
    473     name=name,
    474     profile=profile,
    475     on_unused_input=on_unused_input,
    476     output_keys=output_keys,
    477     fgraph=fgraph,
    478 )

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/compile/function/types.py:1756, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
   1744     m = Maker(
   1745         inputs,
   1746         outputs,
   (...)
   1753         fgraph=fgraph,
   1754     )
   1755     with config.change_flags(compute_test_value="off"):
-> 1756         fn = m.create(defaults)
   1757 finally:
   1758     t2 = time.perf_counter()

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/compile/function/types.py:1649, in FunctionMaker.create(self, input_storage, storage_map)
   1646 start_import_time = pytensor.link.c.cmodule.import_time
   1648 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1649     _fn, _i, _o = self.linker.make_thunk(
   1650         input_storage=input_storage_lists, storage_map=storage_map
   1651     )
   1653 end_linker = time.perf_counter()
   1655 linker_time = end_linker - start_linker

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/basic.py:254, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
    247 def make_thunk(
    248     self,
    249     input_storage: Optional["InputStorageType"] = None,
   (...)
    252     **kwargs,
    253 ) -> Tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 254     return self.make_all(
    255         input_storage=input_storage,
    256         output_storage=output_storage,
    257         storage_map=storage_map,
    258     )[:3]

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/basic.py:697, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
    694 for k in storage_map:
    695     compute_map[k] = [k.owner is None]
--> 697 thunks, nodes, jit_fn = self.create_jitable_thunk(
    698     compute_map, nodes, input_storage, output_storage, storage_map
    699 )
    701 computed, last_user = gc_helper(nodes)
    703 if self.allow_gc:

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
    644 # This is a bit hackish, but we only return one of the output nodes
    645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
    648     self.fgraph,
    649     order=order,
    650     input_storage=input_storage,
    651     output_storage=output_storage,
    652     storage_map=storage_map,
    653 )
    655 thunk_inputs = self.create_thunk_inputs(storage_map)
    657 thunks = []

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/jax/linker.py:59, in JAXLinker.fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs)
     54         input_storage[input_storage_idx] = new_inp_storage
     55         fgraph.remove_input(
     56             fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
     57         )
---> 59 return jax_funcify(
     60     fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
     61 )

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/jax/dispatch/basic.py:51, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
     44 @jax_funcify.register(FunctionGraph)
     45 def jax_funcify_FunctionGraph(
     46     fgraph,
   (...)
     49     **kwargs,
     50 ):
---> 51     return fgraph_to_python(
     52         fgraph,
     53         jax_funcify,
     54         type_conversion_fn=jax_typify,
     55         fgraph_name=fgraph_name,
     56         **kwargs,
     57     )

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/utils.py:738, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    736 body_assigns = []
    737 for node in order:
--> 738     compiled_func = op_conversion_fn(
    739         node.op, node=node, storage_map=storage_map, **kwargs
    740     )
    742     # Create a local alias with a unique name
    743     local_compiled_func_name = unique_name(compiled_func)

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scan.py:23, in jax_funcify_Scan(op, **kwargs)
     21 rewriter = op.mode_instance.optimizer
     22 rewriter(op.fgraph)
---> 23 scan_inner_func = jax_funcify(op.fgraph, **kwargs)
     25 def scan(*outer_inputs):
     26     # Extract JAX scan inputs
     27     outer_inputs = list(outer_inputs)

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/jax/dispatch/basic.py:51, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
     44 @jax_funcify.register(FunctionGraph)
     45 def jax_funcify_FunctionGraph(
     46     fgraph,
   (...)
     49     **kwargs,
     50 ):
---> 51     return fgraph_to_python(
     52         fgraph,
     53         jax_funcify,
     54         type_conversion_fn=jax_typify,
     55         fgraph_name=fgraph_name,
     56         **kwargs,
     57     )

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/utils.py:738, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    736 body_assigns = []
    737 for node in order:
--> 738     compiled_func = op_conversion_fn(
    739         node.op, node=node, storage_map=storage_map, **kwargs
    740     )
    742     # Create a local alias with a unique name
    743     local_compiled_func_name = unique_name(compiled_func)

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/jax/dispatch/basic.py:41, in jax_funcify(op, node, storage_map, **kwargs)
     38 @singledispatch
     39 def jax_funcify(op, node=None, storage_map=None, **kwargs):
     40     """Create a JAX compatible function from an PyTensor `Op`."""
---> 41     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")

NotImplementedError: No JAX conversion for the given `Op`: Dot22
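
The last PyTensor frame in the traceback points at where the default rewrites come from: before converting the inner graph, the JAX dispatcher for Scan runs the optimizer of the mode stored on the Scan op itself (op.mode_instance) rather than a JAX-specific mode. Restated from the frame above for readability:

rewriter = op.mode_instance.optimizer  # the mode the Scan op was built with (the default mode here, since none was passed)
rewriter(op.fgraph)                    # default rewrites run on the inner graph, which is where Dot22 gets introduced
scan_inner_func = jax_funcify(op.fgraph, **kwargs)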

PyTensor version information:

'2.14.2'

Context for the issue:

It adds an obnoxious keyword to build_statespace_graph in pymc_experimental.statespace models and makes it cumbersome for users to quickly try different samplers (they have to remember to change the mode in two places).
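
For reference, a minimal sketch of the current workaround, assuming scan's mode argument controls which rewrites are applied to the inner graph: the target mode has to be stated once when the scan is built and again when the outer function is compiled.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_mode

A = pt.matrix('A')
B = pt.matrix('B')

# Pass the target mode to scan so the inner graph is not rewritten for the
# default backend ...
out, _ = pytensor.scan(
    lambda a, b: a @ b,
    outputs_info=[A],
    non_sequences=[B],
    n_steps=2,
    mode=get_mode('JAX'),
)

# ... and pass it again when compiling the outer function.
f = pytensor.function([A, B], out, mode='JAX')
f(np.eye(3), np.eye(3))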
