22import pickle
33from collections .abc import Callable
44from copy import copy
5+ from functools import singledispatch
56from textwrap import dedent , indent
67from typing import Any
78
@@ -168,25 +169,10 @@ def impl(rng):
168169 return impl
169170
170171
@singledispatch
def core_rv_fn(op: Op):
    """Return the core (unbatched) sampling function for a RandomVariable ``op``.

    Implementations are registered per RV type with ``core_rv_fn.register`` so
    users can define the core case for their own ops. Each registered function
    must return an njit-compiled callable with signature
    ``fn(rng, *dist_params, out)`` that writes its draw(s) into ``out``;
    callables that do so should set ``fn.handles_out = True`` so the caller
    knows no output-storing wrapper is required.

    Parameters
    ----------
    op
        The RandomVariable Op to build a core sampler for.

    Raises
    ------
    NotImplementedError
        If no core sampling function is registered for ``type(op)``.
    """
    # Include the op type in the message so the failure is actionable.
    raise NotImplementedError(
        f"No core sampling function implemented for {type(op).__name__}"
    )
203188
@core_rv_fn.register(ptr.NormalRV)
def core_NormalRV(op):
    """Core sampler for univariate Normal draws.

    Writes ``rng.normal(mu, scale)`` into ``out`` in place.
    """

    @numba_basic.numba_njit
    def normal_sampler(rng, mu, scale, out):
        draw = rng.normal(mu, scale)
        out[...] = draw

    # The sampler stores its own result, so no output wrapper is needed.
    normal_sampler.handles_out = True
    return normal_sampler
197+
198+
@core_rv_fn.register(ptr.CategoricalRV)
def core_CategoricalRV(op):
    """Core sampler for Categorical draws via inverse-CDF on one uniform."""

    @numba_basic.numba_njit
    def categorical_sampler(rng, p, out):
        u = rng.uniform(0, 1)
        # TODO: Check if LLVM can lift constant cumsum(p) out of the loop
        cdf = np.cumsum(p)
        out[...] = np.searchsorted(cdf, u)

    # The sampler stores its own result, so no output wrapper is needed.
    categorical_sampler.handles_out = True
    return categorical_sampler
209+
210+
@core_rv_fn.register(ptr.MvNormalRV)
def core_MvNormalRV(op):
    """Core sampler for multivariate Normal draws.

    Draws ``mean + chol(cov) @ z`` with ``z`` standard normal, writing the
    result into ``out``. Requires ``cov`` to be positive definite (Cholesky).
    """

    # Use the shared numba_basic.numba_njit wrapper for consistency with the
    # other core samplers (the original used `numba.njit` directly, which is
    # inconsistent and relies on an import the siblings do not use).
    @numba_basic.numba_njit
    def random_fn(rng, mean, cov, out):
        chol = np.linalg.cholesky(cov)
        stdnorm = rng.normal(size=cov.shape[-1])
        out[...] = mean + np.dot(chol, stdnorm)

    # The sampler stores its own result, so no output wrapper is needed.
    random_fn.handles_out = True
    return random_fn
223+
224+
225+ @numba_funcify .register (ptr .RandomVariable )
226+ def numba_funcify_RandomVariable (op : RandomVariable , node , ** kwargs ):
227+ _ , size , _ , * args = node .inputs
228+ # None sizes are represented as empty tuple for the time being
229+ # https://github.com/pymc-devs/pytensor/issues/568
230+ [size_len ] = size .type .shape
231+ size_is_None = size_len == 0
232+
233+ inplace = op .inplace
234+
235+ # TODO: Add core_shape to node.inputs
236+ if op .ndim_supp > 0 :
237+ raise NotImplementedError ("Multivariate RandomVariable not implemented yet" )
238+
239+ # TODO: Create a wrapper (string processing?) that takes a core function without outputs
240+ # and saves those outputs in the variables passed by `_vectorized`
241+ core_op_fn = core_rv_fn (op )
242+ if not getattr (core_op_fn , "handles_out" , False ):
243+ # core_op_fn = store_core_outputs(op, core_op_fn)
244+ raise NotImplementedError ()
245+
246+ # TODO: Refactor this code, it's the same with Elemwise
209247 batch_ndim = node .default_output ().ndim - op .ndim_supp
210248 output_bc_patterns = ((False ,) * batch_ndim ,)
211249 input_bc_patterns = tuple (
@@ -234,12 +272,14 @@ def random_wrapper(rng, size, dtype, *inputs):
234272 inplace_pattern_enc ,
235273 (rng ,),
236274 inputs ,
237- None if size_is_None else numba_ndarray .to_fixed_tuple (size , size_len ),
275+ ((),), # TODO: correct core_shapes
276+ None
277+ if size_is_None
278+ else numba_ndarray .to_fixed_tuple (size , size_len ), # size
238279 )
239280 return rng , draws
240281
241282 def random (rng , size , dtype , * inputs ):
242- # TODO: Add code that will be tested for coverage
243283 pass
244284
245285 @overload (random )
@@ -330,35 +370,6 @@ def body_fn(a):
330370 )
331371
332372
333- # @numba_funcify.register(ptr.CategoricalRV)
334- def numba_funcify_CategoricalRV (op , node , ** kwargs ):
335- out_dtype = node .outputs [1 ].type .numpy_dtype
336- size_len = int (get_vector_length (node .inputs [1 ]))
337- p_ndim = node .inputs [- 1 ].ndim
338-
339- @numba_basic .numba_njit
340- def categorical_rv (rng , size , dtype , p ):
341- if not size_len :
342- size_tpl = p .shape [:- 1 ]
343- else :
344- size_tpl = numba_ndarray .to_fixed_tuple (size , size_len )
345- p = np .broadcast_to (p , size_tpl + p .shape [- 1 :])
346-
347- # Workaround https://github.com/numba/numba/issues/8975
348- if not size_len and p_ndim == 1 :
349- unif_samples = np .asarray (np .random .uniform (0 , 1 ))
350- else :
351- unif_samples = np .random .uniform (0 , 1 , size_tpl )
352-
353- res = np .empty (size_tpl , dtype = out_dtype )
354- for idx in np .ndindex (* size_tpl ):
355- res [idx ] = np .searchsorted (np .cumsum (p [idx ]), unif_samples [idx ])
356-
357- return (rng , res )
358-
359- return categorical_rv
360-
361-
362373@numba_funcify .register (ptr .DirichletRV )
363374def numba_funcify_DirichletRV (op , node , ** kwargs ):
364375 out_dtype = node .outputs [1 ].type .numpy_dtype
0 commit comments