-
Notifications
You must be signed in to change notification settings - Fork 143
Description
Describe the issue:
When compiling a function to a non-default backend (JAX, numba), default re-writes are still applied to the inner function under specific conditions. This can result in invalid graphs when the default rewrites don't apply, for example when dot
is rewritten to Dot22
. This appears to only happen when 1) there are sit-sot inputs to scan, and 2) the length of the scan is greater than 1.
I had previously mentioned this to @ricardoV94 because I thought it was a simple case of mode
not propagating to the inner function, but the bug seems a bit more subtle than that.
Reproducible code example:
import pytensor
import pytensor.tensor as pt
import numpy as np
A = pt.matrix('X')
B = pt.matrix('B')
# Works, no error
out, _ = pytensor.scan(lambda a, b: a @ b, non_sequences=[A, B], n_steps=1)
f = pytensor.function([A, B], out, mode='JAX')
f(np.eye(3), np.eye(3))
# Works, no error
out, _ = pytensor.scan(lambda a, b: a @ b, non_sequences=[A, B], n_steps=2)
f = pytensor.function([A, B], out, mode='JAX')
print(f(np.eye(3), np.eye(3)))
# Works, no error
out, _ = pytensor.scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=1)
f = pytensor.function([A, B], out, mode='JAX')
f(np.eye(3), np.eye(3))
# Fails
out, _ = pytensor.scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
f = pytensor.function([A, B], out, mode='JAX')
f(np.eye(3), np.eye(3))
Error message:
NotImplementedError Traceback (most recent call last)
Cell In[36], line 19
17 # Fails
18 out, _ = pytensor.scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
---> 19 f = pytensor.function([A, B], out, mode='JAX')
20 f(np.eye(3), np.eye(3))
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/compile/function/__init__.py:315, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
309 fn = orig_function(
310 inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
311 )
312 else:
313 # note: pfunc will also call orig_function -- orig_function is
314 # a choke point that all compilation must pass through
--> 315 fn = pfunc(
316 params=inputs,
317 outputs=outputs,
318 mode=mode,
319 updates=updates,
320 givens=givens,
321 no_default_updates=no_default_updates,
322 accept_inplace=accept_inplace,
323 name=name,
324 rebuild_strict=rebuild_strict,
325 allow_input_downcast=allow_input_downcast,
326 on_unused_input=on_unused_input,
327 profile=profile,
328 output_keys=output_keys,
329 )
330 return fn
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:468, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
454 profile = ProfileStats(message=profile)
456 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
457 params,
458 outputs,
(...)
465 fgraph=fgraph,
466 )
--> 468 return orig_function(
469 inputs,
470 cloned_outputs,
471 mode,
472 accept_inplace=accept_inplace,
473 name=name,
474 profile=profile,
475 on_unused_input=on_unused_input,
476 output_keys=output_keys,
477 fgraph=fgraph,
478 )
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/compile/function/types.py:1756, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
1744 m = Maker(
1745 inputs,
1746 outputs,
(...)
1753 fgraph=fgraph,
1754 )
1755 with config.change_flags(compute_test_value="off"):
-> 1756 fn = m.create(defaults)
1757 finally:
1758 t2 = time.perf_counter()
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/compile/function/types.py:1649, in FunctionMaker.create(self, input_storage, storage_map)
1646 start_import_time = pytensor.link.c.cmodule.import_time
1648 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1649 _fn, _i, _o = self.linker.make_thunk(
1650 input_storage=input_storage_lists, storage_map=storage_map
1651 )
1653 end_linker = time.perf_counter()
1655 linker_time = end_linker - start_linker
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/basic.py:254, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
247 def make_thunk(
248 self,
249 input_storage: Optional["InputStorageType"] = None,
(...)
252 **kwargs,
253 ) -> Tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 254 return self.make_all(
255 input_storage=input_storage,
256 output_storage=output_storage,
257 storage_map=storage_map,
258 )[:3]
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/basic.py:697, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
694 for k in storage_map:
695 compute_map[k] = [k.owner is None]
--> 697 thunks, nodes, jit_fn = self.create_jitable_thunk(
698 compute_map, nodes, input_storage, output_storage, storage_map
699 )
701 computed, last_user = gc_helper(nodes)
703 if self.allow_gc:
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
644 # This is a bit hackish, but we only return one of the output nodes
645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
648 self.fgraph,
649 order=order,
650 input_storage=input_storage,
651 output_storage=output_storage,
652 storage_map=storage_map,
653 )
655 thunk_inputs = self.create_thunk_inputs(storage_map)
657 thunks = []
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/jax/linker.py:59, in JAXLinker.fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs)
54 input_storage[input_storage_idx] = new_inp_storage
55 fgraph.remove_input(
56 fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
57 )
---> 59 return jax_funcify(
60 fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
61 )
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
905 if not args:
906 raise TypeError(f'{funcname} requires at least '
907 '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/jax/dispatch/basic.py:51, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
44 @jax_funcify.register(FunctionGraph)
45 def jax_funcify_FunctionGraph(
46 fgraph,
(...)
49 **kwargs,
50 ):
---> 51 return fgraph_to_python(
52 fgraph,
53 jax_funcify,
54 type_conversion_fn=jax_typify,
55 fgraph_name=fgraph_name,
56 **kwargs,
57 )
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/utils.py:738, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
736 body_assigns = []
737 for node in order:
--> 738 compiled_func = op_conversion_fn(
739 node.op, node=node, storage_map=storage_map, **kwargs
740 )
742 # Create a local alias with a unique name
743 local_compiled_func_name = unique_name(compiled_func)
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
905 if not args:
906 raise TypeError(f'{funcname} requires at least '
907 '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scan.py:23, in jax_funcify_Scan(op, **kwargs)
21 rewriter = op.mode_instance.optimizer
22 rewriter(op.fgraph)
---> 23 scan_inner_func = jax_funcify(op.fgraph, **kwargs)
25 def scan(*outer_inputs):
26 # Extract JAX scan inputs
27 outer_inputs = list(outer_inputs)
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
905 if not args:
906 raise TypeError(f'{funcname} requires at least '
907 '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/jax/dispatch/basic.py:51, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
44 @jax_funcify.register(FunctionGraph)
45 def jax_funcify_FunctionGraph(
46 fgraph,
(...)
49 **kwargs,
50 ):
---> 51 return fgraph_to_python(
52 fgraph,
53 jax_funcify,
54 type_conversion_fn=jax_typify,
55 fgraph_name=fgraph_name,
56 **kwargs,
57 )
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/utils.py:738, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
736 body_assigns = []
737 for node in order:
--> 738 compiled_func = op_conversion_fn(
739 node.op, node=node, storage_map=storage_map, **kwargs
740 )
742 # Create a local alias with a unique name
743 local_compiled_func_name = unique_name(compiled_func)
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
905 if not args:
906 raise TypeError(f'{funcname} requires at least '
907 '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/pymc-experimental/lib/python3.11/site-packages/pytensor/link/jax/dispatch/basic.py:41, in jax_funcify(op, node, storage_map, **kwargs)
38 @singledispatch
39 def jax_funcify(op, node=None, storage_map=None, **kwargs):
40 """Create a JAX compatible function from an PyTensor `Op`."""
---> 41 raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
NotImplementedError: No JAX conversion for the given `Op`: Dot22
PyTensor version information:
'2.14.2'
Context for the issue:
It adds an obnoxious keyword to build_statespace_graph
in pymc_experimental.statespace
models, and makes it cumbersome for users to quickly try different samplers (they have to remember to change the mode in 2 places)