33from contextlib import contextmanager
44from functools import singledispatch
55from textwrap import dedent
6- from typing import Union
6+ from typing import TYPE_CHECKING , Callable , Optional , Union , cast
77
88import numba
99import numba .np .unsafe .ndarray as numba_ndarray
2222from pytensor .compile .ops import DeepCopyOp
2323from pytensor .graph .basic import Apply , NoParams
2424from pytensor .graph .fg import FunctionGraph
25+ from pytensor .graph .op import Op
2526from pytensor .graph .type import Type
2627from pytensor .ifelse import IfElse
2728from pytensor .link .utils import (
4849from pytensor .tensor .type_other import MakeSlice , NoneConst
4950
5051
52+ if TYPE_CHECKING :
53+ from pytensor .graph .op import StorageMapType
54+
55+
5156def numba_njit (* args , ** kwargs ):
5257
5358 if len (args ) > 0 and callable (args [0 ]):
@@ -339,9 +344,42 @@ def numba_const_convert(data, dtype=None, **kwargs):
339344 return data
340345
341346
def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable:
    """Convert `obj` to a Numba-JITable object.

    Public entry point; forwards all arguments to the single-dispatch
    function `_numba_funcify`, which selects the converter by ``type(obj)``.
    """
    return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)
350+
351+
342352@singledispatch
343- def numba_funcify (op , node = None , storage_map = None , ** kwargs ):
344- """Create a Numba compatible function from an PyTensor `Op`."""
353+ def _numba_funcify (
354+ obj ,
355+ node : Optional [Apply ] = None ,
356+ storage_map : Optional ["StorageMapType" ] = None ,
357+ ** kwargs ,
358+ ) -> Callable :
359+ r"""Dispatch on PyTensor object types to perform Numba conversions.
360+
361+ Arguments
362+ ---------
363+ obj
364+ The object used to determine the appropriate conversion function based
365+ on its type. This is generally an `Op` instance, but `FunctionGraph`\s
366+ are also supported.
367+ node
368+ When `obj` is an `Op`, this value should be the corresponding `Apply` node.
369+ storage_map
370+ A storage map with, for example, the constant and `SharedVariable` values
371+ of the graph being converted.
372+
373+ Returns
374+ -------
375+ A `Callable` that can be JIT-compiled in Numba using `numba.jit`.
376+
377+ """
378+
379+
380+ @_numba_funcify .register (Op )
381+ def numba_funcify_perform (op , node , storage_map = None , ** kwargs ) -> Callable :
382+ """Create a Numba compatible function from an PyTensor `Op.perform`."""
345383
346384 warnings .warn (
347385 f"Numba will use object mode to run { op } 's perform method" ,
@@ -392,10 +430,10 @@ def perform(*inputs):
392430 ret = py_perform_return (inputs )
393431 return ret
394432
395- return perform
433+ return cast ( Callable , perform )
396434
397435
398- @numba_funcify .register (OpFromGraph )
436+ @_numba_funcify .register (OpFromGraph )
399437def numba_funcify_OpFromGraph (op , node = None , ** kwargs ):
400438
401439 _ = kwargs .pop ("storage_map" , None )
@@ -417,7 +455,7 @@ def opfromgraph(*inputs):
417455 return opfromgraph
418456
419457
420- @numba_funcify .register (FunctionGraph )
458+ @_numba_funcify .register (FunctionGraph )
421459def numba_funcify_FunctionGraph (
422460 fgraph ,
423461 node = None ,
@@ -525,9 +563,9 @@ def {fn_name}({", ".join(input_names)}):
525563 return subtensor_def_src
526564
527565
528- @numba_funcify .register (Subtensor )
529- @numba_funcify .register (AdvancedSubtensor )
530- @numba_funcify .register (AdvancedSubtensor1 )
566+ @_numba_funcify .register (Subtensor )
567+ @_numba_funcify .register (AdvancedSubtensor )
568+ @_numba_funcify .register (AdvancedSubtensor1 )
531569def numba_funcify_Subtensor (op , node , ** kwargs ):
532570
533571 subtensor_def_src = create_index_func (
@@ -543,8 +581,8 @@ def numba_funcify_Subtensor(op, node, **kwargs):
543581 return numba_njit (subtensor_fn )
544582
545583
546- @numba_funcify .register (IncSubtensor )
547- @numba_funcify .register (AdvancedIncSubtensor )
584+ @_numba_funcify .register (IncSubtensor )
585+ @_numba_funcify .register (AdvancedIncSubtensor )
548586def numba_funcify_IncSubtensor (op , node , ** kwargs ):
549587
550588 incsubtensor_def_src = create_index_func (
@@ -560,7 +598,7 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
560598 return numba_njit (incsubtensor_fn )
561599
562600
563- @numba_funcify .register (AdvancedIncSubtensor1 )
601+ @_numba_funcify .register (AdvancedIncSubtensor1 )
564602def numba_funcify_AdvancedIncSubtensor1 (op , node , ** kwargs ):
565603 inplace = op .inplace
566604 set_instead_of_inc = op .set_instead_of_inc
@@ -593,7 +631,7 @@ def advancedincsubtensor1(x, vals, idxs):
593631 return advancedincsubtensor1
594632
595633
596- @numba_funcify .register (DeepCopyOp )
634+ @_numba_funcify .register (DeepCopyOp )
597635def numba_funcify_DeepCopyOp (op , node , ** kwargs ):
598636
599637 # Scalars are apparently returned as actual Python scalar types and not
@@ -615,26 +653,26 @@ def deepcopyop(x):
615653 return deepcopyop
616654
617655
@_numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, node, **kwargs):
    """Generate a Numba implementation of `MakeSlice`."""

    @numba_njit
    def makeslice(*slice_args):
        return slice(*slice_args)

    return makeslice
625663
626664
@_numba_funcify.register(Shape)
def numba_funcify_Shape(op, node, **kwargs):
    """Generate a Numba implementation of `Shape`."""

    @numba_njit(inline="always")
    def shape(x):
        # ``np.shape`` returns a tuple; wrap it in an array as the output value.
        return np.asarray(np.shape(x))

    return shape
634672
635673
636- @numba_funcify .register (Shape_i )
637- def numba_funcify_Shape_i (op , ** kwargs ):
674+ @_numba_funcify .register (Shape_i )
675+ def numba_funcify_Shape_i (op , node , ** kwargs ):
638676 i = op .i
639677
640678 @numba_njit (inline = "always" )
@@ -664,8 +702,8 @@ def codegen(context, builder, signature, args):
664702 return sig , codegen
665703
666704
667- @numba_funcify .register (Reshape )
668- def numba_funcify_Reshape (op , ** kwargs ):
705+ @_numba_funcify .register (Reshape )
706+ def numba_funcify_Reshape (op , node , ** kwargs ):
669707 ndim = op .ndim
670708
671709 if ndim == 0 :
@@ -687,7 +725,7 @@ def reshape(x, shape):
687725 return reshape
688726
689727
690- @numba_funcify .register (SpecifyShape )
728+ @_numba_funcify .register (SpecifyShape )
691729def numba_funcify_SpecifyShape (op , node , ** kwargs ):
692730 shape_inputs = node .inputs [1 :]
693731 shape_input_names = ["shape_" + str (i ) for i in range (len (shape_inputs ))]
@@ -734,7 +772,7 @@ def inputs_cast(x):
734772 return inputs_cast
735773
736774
737- @numba_funcify .register (Dot )
775+ @_numba_funcify .register (Dot )
738776def numba_funcify_Dot (op , node , ** kwargs ):
739777 # Numba's `np.dot` does not support integer dtypes, so we need to cast to
740778 # float.
@@ -749,7 +787,7 @@ def dot(x, y):
749787 return dot
750788
751789
752- @numba_funcify .register (Softplus )
790+ @_numba_funcify .register (Softplus )
753791def numba_funcify_Softplus (op , node , ** kwargs ):
754792
755793 x_dtype = np .dtype (node .inputs [0 ].dtype )
@@ -768,7 +806,7 @@ def softplus(x):
768806 return softplus
769807
770808
771- @numba_funcify .register (Cholesky )
809+ @_numba_funcify .register (Cholesky )
772810def numba_funcify_Cholesky (op , node , ** kwargs ):
773811 lower = op .lower
774812
@@ -804,7 +842,7 @@ def cholesky(a):
804842 return cholesky
805843
806844
807- @numba_funcify .register (Solve )
845+ @_numba_funcify .register (Solve )
808846def numba_funcify_Solve (op , node , ** kwargs ):
809847
810848 assume_a = op .assume_a
@@ -851,7 +889,7 @@ def solve(a, b):
851889 return solve
852890
853891
854- @numba_funcify .register (BatchedDot )
892+ @_numba_funcify .register (BatchedDot )
855893def numba_funcify_BatchedDot (op , node , ** kwargs ):
856894 dtype = node .outputs [0 ].type .numpy_dtype
857895
@@ -872,7 +910,7 @@ def batched_dot(x, y):
872910# optimizations are apparently already performed by Numba
873911
874912
875- @numba_funcify .register (IfElse )
913+ @_numba_funcify .register (IfElse )
876914def numba_funcify_IfElse (op , ** kwargs ):
877915 n_outs = op .n_outs
878916
0 commit comments