|
34 | 34 |
|
35 | 35 | from pytensor.compile import DeepCopyOp, Function, get_mode |
36 | 36 | from pytensor.compile.sharedvalue import SharedVariable |
37 | | -from pytensor.graph.basic import Constant, Variable, graph_inputs |
| 37 | +from pytensor.graph.basic import Constant, Variable, ancestors, graph_inputs |
38 | 38 | from pytensor.tensor.random.op import RandomVariable |
39 | 39 | from pytensor.tensor.random.type import RandomType |
40 | 40 | from pytensor.tensor.variable import TensorConstant, TensorVariable |
@@ -1241,15 +1241,13 @@ def register_rv( |
1241 | 1241 | self.add_named_variable(rv_var, dims) |
1242 | 1242 | self.set_initval(rv_var, initval) |
1243 | 1243 | else: |
1244 | | - if ( |
1245 | | - isinstance(observed, TensorVariable) |
1246 | | - and observed.owner is not None |
1247 | | - and isinstance(observed.owner.op, MinibatchOp) |
1248 | | - and total_size is None |
1249 | | - ): |
1250 | | - warnings.warn( |
1251 | | - f"total_size not provided for observed variable `{name}` that uses pm.Minibatch" |
1252 | | - ) |
| 1244 | + if total_size is None and isinstance(observed, TensorVariable): |
| 1245 | + for node in ancestors([observed]): |
| 1246 | + if node.owner is not None and isinstance(node.owner.op, MinibatchOp): |
| 1247 | + warnings.warn( |
| 1248 | + f"total_size not provided for observed variable `{name}` that uses pm.Minibatch" |
| 1249 | + ) |
| 1250 | + break |
1253 | 1251 | if not is_valid_observed(observed): |
1254 | 1252 | raise TypeError( |
1255 | 1253 | "Variables that depend on other nodes cannot be used for observed data." |
|
0 commit comments