1313from pymc .distributions .transforms import Chain
1414from pymc .logprob .transforms import IntervalTransform
1515from pymc .model import Model
16- from pymc .pytensorf import compile_pymc , constant_fold , toposort_replace
16+ from pymc .pytensorf import compile_pymc , constant_fold
1717from pymc .util import RandomState , _get_seeds_per_chain , treedict
1818from pytensor .graph import FunctionGraph , clone_replace
19- from pytensor .graph .basic import truncated_graph_inputs , Constant , ancestors
2019from pytensor .graph .replace import vectorize_graph
21- from pytensor .tensor import TensorVariable , extract_constant
20+ from pytensor .tensor import TensorVariable
2221from pytensor .tensor .special import log_softmax
2322
2423__all__ = ["MarginalModel" , "marginalize" ]
2524
2625from pymc_experimental .distributions import DiscreteMarkovChain
27- from pymc_experimental .model .marginal .distributions import FiniteDiscreteMarginalRV , DiscreteMarginalMarkovChainRV , \
28- get_domain_of_finite_discrete_rv , _add_reduce_batch_dependent_logps
29- from pymc_experimental .model .marginal .graph_analysis import find_conditional_input_rvs , is_conditional_dependent , \
30- find_conditional_dependent_rvs , subgraph_dim_connection , collect_shared_vars
26+ from pymc_experimental .model .marginal .distributions import (
27+ DiscreteMarginalMarkovChainRV ,
28+ FiniteDiscreteMarginalRV ,
29+ _add_reduce_batch_dependent_logps ,
30+ get_domain_of_finite_discrete_rv ,
31+ )
32+ from pymc_experimental .model .marginal .graph_analysis import (
33+ collect_shared_vars ,
34+ find_conditional_dependent_rvs ,
35+ find_conditional_input_rvs ,
36+ is_conditional_dependent ,
37+ subgraph_dim_connection ,
38+ )
3139
3240ModelRVs = TensorVariable | Sequence [TensorVariable ] | str | Sequence [str ]
3341
@@ -537,10 +545,6 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
537545
538546
539547def replace_finite_discrete_marginal_subgraph (fgraph , rv_to_marginalize , all_rvs ):
540- # TODO: This should eventually be integrated in a more general routine that can
541- # identify other types of supported marginalization, of which finite discrete
542- # RVs is just one
543-
544548 dependent_rvs = find_conditional_dependent_rvs (rv_to_marginalize , all_rvs )
545549 if not dependent_rvs :
546550 raise ValueError (f"No RVs depend on marginalized RV { rv_to_marginalize } " )
@@ -552,7 +556,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
552556 if rv is not rv_to_marginalize
553557 ]
554558
555- if all ( rv_to_marginalize .type .broadcastable ) :
559+ if rv_to_marginalize .type .ndim == 0 :
556560 ndim_supp = max ([dependent_rv .type .ndim for dependent_rv in dependent_rvs ])
557561 else :
558562 # If the marginalized RV has multiple dimensions, check that graph between
@@ -561,23 +565,27 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
561565 dependent_rvs_dim_connections = subgraph_dim_connection (
562566 rv_to_marginalize , other_direct_rv_ancestors , dependent_rvs
563567 )
564- # dependent_rvs_dim_connections = subgraph_dim_connection(
565- # rv_to_marginalize, other_inputs, dependent_rvs
566- # )
567568
568- ndim_supp = max ((dependent_rv .type .ndim - rv_to_marginalize .type .ndim ) for dependent_rv in dependent_rvs )
569+ ndim_supp = max (
570+ (dependent_rv .type .ndim - rv_to_marginalize .type .ndim ) for dependent_rv in dependent_rvs
571+ )
569572
570- if any (len (dim ) > 1 for rv_dim_connections in dependent_rvs_dim_connections for dim in rv_dim_connections ):
573+ if any (
574+ len (dim ) > 1
575+ for rv_dim_connections in dependent_rvs_dim_connections
576+ for dim in rv_dim_connections
577+ ):
571578 raise NotImplementedError ("Multiple dimensions are mixed" )
572579
573580 # We further check that:
574581 # 1) Dimensions of dependent RVs are aligned with those of the marginalized RV
575582 # 2) Any extra batch dimensions of dependent RVs beyond those implied by the MarginalizedRV
576583 # show up on the right, so that collapsing logic in logp can be more straightforward.
577- # This also ensures the MarginalizedRV still behaves as an RV itself
578584 marginal_batch_ndim = rv_to_marginalize .owner .op .batch_ndim (rv_to_marginalize .owner )
579585 marginal_batch_dims = tuple ((i ,) for i in range (marginal_batch_ndim ))
580- for dependent_rv , dependent_rv_batch_dims in zip (dependent_rvs , dependent_rvs_dim_connections ):
586+ for dependent_rv , dependent_rv_batch_dims in zip (
587+ dependent_rvs , dependent_rvs_dim_connections
588+ ):
581589 extra_batch_ndim = dependent_rv .type .ndim - marginal_batch_ndim
582590 valid_dependent_batch_dims = marginal_batch_dims + (((),) * extra_batch_ndim )
583591 if dependent_rv_batch_dims != valid_dependent_batch_dims :
@@ -587,47 +595,21 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
587595 )
588596
589597 input_rvs = [* marginalized_rv_input_rvs , * other_direct_rv_ancestors ]
590- rvs_to_marginalize = [rv_to_marginalize , * dependent_rvs ]
598+ output_rvs = [rv_to_marginalize , * dependent_rvs ]
591599
592- outputs = rvs_to_marginalize
593600 # We are strict about shared variables in SymbolicRandomVariables
594- inputs = input_rvs + collect_shared_vars (rvs_to_marginalize , blockers = input_rvs )
595- # inputs = [
596- # inp
597- # for rv in rvs_to_marginalize # should be toposort
598- # for inp in rv.owner.inputs
599- # if not(all(isinstance(a, Constant) for a in ancestors([inp], blockers=all_rvs)))
600- # ]
601- # inputs = [
602- # inp for inp in truncated_graph_inputs(outputs, ancestors_to_include=inputs)
603- # if not (all(isinstance(a, Constant) for a in ancestors([inp], blockers=all_rvs)))
604- # ]
605- # inputs = truncated_graph_inputs(outputs, ancestors_to_include=[
606- # # inp
607- # # for output in outputs
608- # # for inp in output.owner.inputs
609- # # ])
610- # inputs = [inp for inp in inputs if not isinstance(constant_fold([inp], raise_not_constant=False)[0], Constant | np.ndarray)]
601+ inputs = input_rvs + collect_shared_vars (output_rvs , blockers = input_rvs )
602+
611603 if isinstance (rv_to_marginalize .owner .op , DiscreteMarkovChain ):
612604 marginalize_constructor = DiscreteMarginalMarkovChainRV
613605 else :
614606 marginalize_constructor = FiniteDiscreteMarginalRV
615607
616608 marginalization_op = marginalize_constructor (
617609 inputs = inputs ,
618- outputs = outputs ,
610+ outputs = output_rvs , # TODO: Add RNG updates to outputs
619611 ndim_supp = ndim_supp ,
620612 )
621-
622- marginalized_rvs = marginalization_op (* inputs )
623- print ()
624- import pytensor
625- pytensor .dprint (marginalized_rvs , print_type = True )
626- fgraph .replace_all (reversed (tuple (zip (rvs_to_marginalize , marginalized_rvs ))))
627- # assert 0
628- # fgraph.dprint()
629- # assert 0
630- # toposort_replace(fgraph, tuple(zip(rvs_to_marginalize, marginalized_rvs)))
631- # assert 0
632- return rvs_to_marginalize , marginalized_rvs
633-
613+ new_output_rvs = marginalization_op (* inputs )
614+ fgraph .replace_all (tuple (zip (output_rvs , new_output_rvs )))
615+ return output_rvs , new_output_rvs
0 commit comments