2222
2323from pytensor import function
2424from pytensor .graph .basic import ancestors , walk
25- from pytensor .scalar .basic import Cast
26- from pytensor .tensor .elemwise import Elemwise
2725from pytensor .tensor .shape import Shape
2826from pytensor .tensor .variable import TensorVariable
2927
@@ -299,35 +297,28 @@ def make_compute_graph(
299297 self , var_names : Iterable [VarName ] | None = None
300298 ) -> dict [VarName , set [VarName ]]:
301299 """Get map of var_name -> set(input var names) for the model."""
300+ model = self .model
301+ named_vars = self ._all_vars
302302 input_map : dict [VarName , set [VarName ]] = defaultdict (set )
303303
304- for var_name in self .vars_to_plot (var_names ):
305- var = self .model [var_name ]
306- parent_name = self .get_parent_names (var )
307- input_map [var_name ] = input_map [var_name ].union (parent_name )
308-
309- if var in self .model .observed_RVs :
310- obs_node = self .model .rvs_to_values [var ]
311-
312- # loop created so that the elif block can go through this again
313- # and remove any intermediate ops, notably dtype casting, to observations
314- while True :
315- obs_name = obs_node .name
316- if obs_name and obs_name != var_name :
317- input_map [var_name ] = input_map [var_name ].difference ({obs_name })
318- input_map [obs_name ] = input_map [obs_name ].union ({var_name })
319- break
320- elif (
321- # for cases where observations are cast to a certain dtype
322- # see issue 5795: https://github.com/pymc-devs/pymc/issues/5795
323- obs_node .owner
324- and isinstance (obs_node .owner .op , Elemwise )
325- and isinstance (obs_node .owner .op .scalar_op , Cast )
326- ):
327- # we can retrieve the observation node by going up the graph
328- obs_node = obs_node .owner .inputs [0 ]
329- else :
330- break
304+ var_names_to_plot = self .vars_to_plot (var_names )
305+ for var_name in var_names_to_plot :
306+ parent_names = self .get_parent_names (model [var_name ])
307+ input_map [var_name ].update (parent_names )
308+
309+ for var_name in var_names_to_plot :
310+ if (var := model [var_name ]) in model .observed_RVs :
311+ # Make observed `Data` variables flow from the observed RV, and not the other way around
312+ # (In the generative graph they usually inform shape of the observed RV)
313+ # We have to iterate over the ancestors of the observed values because there can be
314+ # deterministic operations in between the `Data` variable and the observed value.
315+ obs_var = model .rvs_to_values [var ]
316+ for ancestor in ancestors ([obs_var ]):
317+ if ancestor not in named_vars :
318+ continue
319+ obs_name = cast (VarName , ancestor .name )
320+ input_map [var_name ].discard (obs_name )
321+ input_map [obs_name ].add (var_name )
331322
332323 return input_map
333324
@@ -348,7 +339,7 @@ def get_plates(
348339 plates = defaultdict (set )
349340
350341 # TODO: Evaluate all RV shapes at once
351- # This should help find discrepencies , and
342+ # This should help find discrepancies , and
352343 # avoids unnecessary function compiles for determining labels.
353344 dim_lengths : dict [str , int ] = {
354345 dim_name : fast_eval (value ).item () for dim_name , value in self .model .dim_lengths .items ()
0 commit comments