2626 vectorize_graph ,
2727)
2828from pytensor .scan import map as scan_map
29- from pytensor .tensor import TensorVariable
29+ from pytensor .tensor import TensorType , TensorVariable
3030from pytensor .tensor .elemwise import Elemwise
3131from pytensor .tensor .shape import Shape
3232from pytensor .tensor .special import log_softmax
@@ -379,41 +379,36 @@ def transform_input(inputs):
379379
380380 rv_dict = {}
381381 rv_dims = {}
382- for seed , rv in zip (seeds , vars_to_recover ):
382+ for seed , marginalized_rv in zip (seeds , vars_to_recover ):
383383 supported_dists = (Bernoulli , Categorical , DiscreteUniform )
384- if not isinstance (rv .owner .op , supported_dists ):
384+ if not isinstance (marginalized_rv .owner .op , supported_dists ):
385385 raise NotImplementedError (
386- f"RV with distribution { rv .owner .op } cannot be recovered. "
386+ f"RV with distribution { marginalized_rv .owner .op } cannot be recovered. "
387387 f"Supported distribution include { supported_dists } "
388388 )
389389
390390 m = self .clone ()
391- rv = m .vars_to_clone [rv ]
392- m .unmarginalize ([rv ])
393- dependent_vars = find_conditional_dependent_rvs (rv , m .basic_RVs )
394- joint_logps = m .logp (vars = dependent_vars + [ rv ] , sum = False )
391+ marginalized_rv = m .vars_to_clone [marginalized_rv ]
392+ m .unmarginalize ([marginalized_rv ])
393+ dependent_vars = find_conditional_dependent_rvs (marginalized_rv , m .basic_RVs )
394+ joint_logps = m .logp (vars = [ marginalized_rv ] + dependent_vars , sum = False )
395395
396- marginalized_value = m .rvs_to_values [rv ]
396+ marginalized_value = m .rvs_to_values [marginalized_rv ]
397397 other_values = [v for v in m .value_vars if v is not marginalized_value ]
398398
399399 # Handle batch dims for marginalized value and its dependent RVs
400- joint_logp = joint_logps [- 1 ]
401- for dv in joint_logps [:- 1 ]:
402- dbcast = dv .type .broadcastable
403- mbcast = marginalized_value .type .broadcastable
404- mbcast = (True ,) * (len (dbcast ) - len (mbcast )) + mbcast
405- values_axis_bcast = [
406- i for i , (m , v ) in enumerate (zip (mbcast , dbcast )) if m and not v
407- ]
408- joint_logp += dv .sum (values_axis_bcast )
400+ marginalized_logp , * dependent_logps = joint_logps
401+ joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps (
402+ marginalized_rv .type , dependent_logps
403+ )
409404
410- rv_shape = constant_fold (tuple (rv .shape ))
411- rv_domain = get_domain_of_finite_discrete_rv (rv )
405+ rv_shape = constant_fold (tuple (marginalized_rv .shape ))
406+ rv_domain = get_domain_of_finite_discrete_rv (marginalized_rv )
412407 rv_domain_tensor = pt .moveaxis (
413408 pt .full (
414409 (* rv_shape , len (rv_domain )),
415410 rv_domain ,
416- dtype = rv .dtype ,
411+ dtype = marginalized_rv .dtype ,
417412 ),
418413 - 1 ,
419414 0 ,
@@ -429,7 +424,7 @@ def transform_input(inputs):
429424 joint_logps_norm = log_softmax (joint_logps , axis = - 1 )
430425 if return_samples :
431426 sample_rv_outs = pymc .Categorical .dist (logit_p = joint_logps )
432- if isinstance (rv .owner .op , DiscreteUniform ):
427+ if isinstance (marginalized_rv .owner .op , DiscreteUniform ):
433428 sample_rv_outs += rv_domain [0 ]
434429
435430 rv_loglike_fn = compile_pymc (
@@ -454,18 +449,20 @@ def transform_input(inputs):
454449 logps , samples = zip (* logvs )
455450 logps = np .array (logps )
456451 samples = np .array (samples )
457- rv_dict [rv .name ] = samples .reshape (
452+ rv_dict [marginalized_rv .name ] = samples .reshape (
458453 tuple (len (coord ) for coord in stacked_dims .values ()) + samples .shape [1 :],
459454 )
460455 else :
461456 logps = np .array (logvs )
462457
463- rv_dict ["lp_" + rv .name ] = logps .reshape (
458+ rv_dict ["lp_" + marginalized_rv .name ] = logps .reshape (
464459 tuple (len (coord ) for coord in stacked_dims .values ()) + logps .shape [1 :],
465460 )
466- if rv .name in m .named_vars_to_dims :
467- rv_dims [rv .name ] = list (m .named_vars_to_dims [rv .name ])
468- rv_dims ["lp_" + rv .name ] = rv_dims [rv .name ] + ["lp_" + rv .name + "_dim" ]
461+ if marginalized_rv .name in m .named_vars_to_dims :
462+ rv_dims [marginalized_rv .name ] = list (m .named_vars_to_dims [marginalized_rv .name ])
463+ rv_dims ["lp_" + marginalized_rv .name ] = rv_dims [marginalized_rv .name ] + [
464+ "lp_" + marginalized_rv .name + "_dim"
465+ ]
469466
470467 coords , dims = coords_and_dims_for_inferencedata (self )
471468 dims .update (rv_dims )
@@ -645,6 +642,22 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
645642 raise NotImplementedError (f"Cannot compute domain for op { op } " )
646643
647644
def _add_reduce_batch_dependent_logps(
    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
):
    """Sum the dependent logps after collapsing their extra batch dimensions.

    For each entry of ``dependent_logps``, the broadcastable pattern of
    ``marginalized_type`` is left-padded with ``True`` until it has as many
    entries as the logp has dimensions. Every axis where the padded pattern is
    broadcastable but the logp is not is summed out. The reduced logps are
    then combined with a single ``pt.add`` call.

    Parameters
    ----------
    marginalized_type
        Type of the marginalized RV; only its ``broadcastable`` pattern is read.
    dependent_logps
        Logp terms of the RVs that depend on the marginalized RV.

    Returns
    -------
    The result of ``pt.add`` applied to all reduced logps.
    """
    marginalized_bcast = marginalized_type.broadcastable

    def _reduce_extra_batch_dims(logp):
        # Collapse the axes of one logp that the marginalized RV does not span.
        logp_bcast = logp.type.broadcastable
        # Left-pad so both patterns are compared right-aligned, axis by axis.
        n_missing_dims = len(logp_bcast) - len(marginalized_bcast)
        aligned_bcast = (True,) * n_missing_dims + marginalized_bcast
        axes_to_sum = [
            axis
            for axis, (marginalized_is_bcast, logp_is_bcast) in enumerate(
                zip(aligned_bcast, logp_bcast)
            )
            if marginalized_is_bcast and not logp_is_bcast
        ]
        return logp.sum(axes_to_sum)

    return pt.add(*[_reduce_extra_batch_dims(logp) for logp in dependent_logps])
659+
660+
648661@_logprob .register (FiniteDiscreteMarginalRV )
649662def finite_discrete_marginal_rv_logp (op , values , * inputs , ** kwargs ):
650663 # Clone the inner RV graph of the Marginalized RV
@@ -660,17 +673,12 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
660673 logps_dict = conditional_logp (rv_values = inner_rvs_to_values , ** kwargs )
661674
662675 # Reduce logp dimensions corresponding to broadcasted variables
663- joint_logp = logps_dict [inner_rvs_to_values [marginalized_rv ]]
664- for inner_rv , inner_value in inner_rvs_to_values .items ():
665- if inner_rv is marginalized_rv :
666- continue
667- vbcast = inner_value .type .broadcastable
668- mbcast = marginalized_rv .type .broadcastable
669- mbcast = (True ,) * (len (vbcast ) - len (mbcast )) + mbcast
670- values_axis_bcast = [i for i , (m , v ) in enumerate (zip (mbcast , vbcast )) if m != v ]
671- joint_logp += logps_dict [inner_value ].sum (values_axis_bcast , keepdims = True )
672-
673- # Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
676+ marginalized_logp = logps_dict .pop (inner_rvs_to_values [marginalized_rv ])
677+ joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps (
678+ marginalized_rv .type , logps_dict .values ()
679+ )
680+
681+ # Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
674682 # values of the marginalized RV
675683 # Some inputs are not root inputs (such as transformed projections of value variables)
676684 # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
@@ -698,6 +706,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
698706 )
699707
700708 # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
709+ # TODO: Try vectorize here
701710 if len (marginalized_rv_domain ) <= 10 :
702711 joint_logps = [
703712 joint_logp_op (marginalized_rv_domain_tensor [i ], * values , * inputs )
0 commit comments