|
| 1 | +from itertools import zip_longest |
| 2 | + |
1 | 3 | from pytensor.compile import optdb |
2 | 4 | from pytensor.configdefaults import config |
3 | 5 | from pytensor.graph.op import compute_test_value |
4 | 6 | from pytensor.graph.rewriting.basic import in2out, node_rewriter |
| 7 | +from pytensor.tensor import NoneConst |
5 | 8 | from pytensor.tensor.basic import constant, get_vector_length |
6 | 9 | from pytensor.tensor.elemwise import DimShuffle |
7 | 10 | from pytensor.tensor.extra_ops import broadcast_to |
|
17 | 20 | get_idx_list, |
18 | 21 | indexed_result_shape, |
19 | 22 | ) |
| 23 | +from pytensor.tensor.type_other import SliceType |
20 | 24 |
|
21 | 25 |
|
22 | 26 | def is_rv_used_in_graph(base_rv, node, fgraph): |
@@ -196,141 +200,104 @@ def local_dimshuffle_rv_lift(fgraph, node): |
196 | 200 | def local_subtensor_rv_lift(fgraph, node): |
197 | 201 | """Lift a ``*Subtensor`` through ``RandomVariable`` inputs. |
198 | 202 |
|
199 | | - In a fashion similar to ``local_dimshuffle_rv_lift``, the indexed dimensions |
200 | | - need to be separated into distinct replication-space and (independent) |
201 | | - parameter-space ``*Subtensor``s. |
202 | | -
|
203 | | - The replication-space ``*Subtensor`` can be used to determine a |
204 | | - sub/super-set of the replication-space and, thus, a "smaller"/"larger" |
205 | | - ``size`` tuple. The parameter-space ``*Subtensor`` is simply lifted and |
206 | | - applied to the distribution parameters. |
207 | | -
|
208 | | - Consider the following example graph: |
209 | | - ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``. The |
210 | | - ``*Subtensor`` ``Op`` requests indices ``idx1``, ``idx2``, and ``idx3``, |
211 | | - which correspond to all three ``size`` dimensions. Now, depending on the |
212 | | - broadcasted dimensions of ``mu`` and ``std``, this ``*Subtensor`` ``Op`` |
213 | | - could be reducing the ``size`` parameter and/or sub-setting the independent |
214 | | - ``mu`` and ``std`` parameters. Only once the dimensions are properly |
215 | | - separated into the two replication/parameter subspaces can we determine how |
216 | | - the ``*Subtensor`` indices are distributed. |
217 | | - For instance, ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]`` |
218 | | - could become |
219 | | - ``normal(mu[idx1], std[idx2], size=np.shape(idx1) + np.shape(idx2) + np.shape(idx3))`` |
220 | | - if ``mu.shape == std.shape == ()`` |
221 | | -
|
222 | | - ``normal`` is a rather simple case, because it's univariate. Multivariate |
223 | | - cases require a mapping between the parameter space and the image of the |
224 | | - random variable. This may not always be possible, but for many common |
225 | | - distributions it is. For example, the dimensions of the multivariate |
226 | | - normal's image can be mapped directly to each dimension of its parameters. |
227 | | - We use these mappings to change a graph like ``multivariate_normal(mu, Sigma)[idx1]`` |
228 | | - into ``multivariate_normal(mu[idx1], Sigma[idx1, idx1])``. |
| 203 | + For example, ``normal(mu, std)[0] == normal(mu[0], std[0])``. |
229 | 204 |
|
| 205 | + This rewrite also applies to multivariate distributions, as long |
| 206 | + as indexing does not act on any core (support) dimensions — e.g., |
| 207 | + it does not apply to ``mvnormal(mu, cov, size=(2,))[0, 0]``. |
230 | 208 | """ |
231 | 209 |
|
232 | 210 | st_op = node.op |
233 | 211 |
|
234 | 212 | if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)): |
235 | 213 | return False |
236 | 214 |
|
237 | | - base_rv = node.inputs[0] |
| 215 | + rv = node.inputs[0] |
| 216 | + rv_node = rv.owner |
238 | 217 |
|
239 | | - rv_node = base_rv.owner |
240 | 218 | if not (rv_node and isinstance(rv_node.op, RandomVariable)): |
241 | 219 | return False |
242 | 220 |
|
243 | | - # If no one else is using the underlying `RandomVariable`, then we can |
244 | | - # do this; otherwise, the graph would be internally inconsistent. |
245 | | - if is_rv_used_in_graph(base_rv, node, fgraph): |
246 | | - return False |
247 | | - |
248 | 221 | rv_op = rv_node.op |
249 | 222 | rng, size, dtype, *dist_params = rv_node.inputs |
250 | 223 |
|
251 | | - # TODO: Remove this once the multi-dimensional changes described below are |
252 | | - # in place. |
253 | | - if rv_op.ndim_supp > 0: |
254 | | - return False |
255 | | - |
256 | | - rv_op = base_rv.owner.op |
257 | | - rng, size, dtype, *dist_params = base_rv.owner.inputs |
258 | | - |
| 224 | + # Parse indices |
259 | 225 | idx_list = getattr(st_op, "idx_list", None) |
260 | 226 | if idx_list: |
261 | 227 | cdata = get_idx_list(node.inputs, idx_list) |
262 | 228 | else: |
263 | 229 | cdata = node.inputs[1:] |
264 | | - |
265 | 230 | st_indices, st_is_bool = zip( |
266 | 231 | *tuple( |
267 | 232 | (as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata |
268 | 233 | ) |
269 | 234 | ) |
270 | 235 |
|
271 | | - # We need to separate dimensions into replications and independents |
272 | | - num_ind_dims = None |
273 | | - if len(dist_params) == 1: |
274 | | - num_ind_dims = dist_params[0].ndim |
275 | | - else: |
276 | | - # When there is more than one distribution parameter, assume that all |
277 | | - # of them will broadcast to the maximum number of dimensions |
278 | | - num_ind_dims = max(d.ndim for d in dist_params) |
279 | | - |
280 | | - reps_ind_split_idx = base_rv.ndim - (num_ind_dims + rv_op.ndim_supp) |
281 | | - |
282 | | - if len(st_indices) > reps_ind_split_idx: |
283 | | - # These are the indices that need to be applied to the parameters |
284 | | - ind_indices = tuple(st_indices[reps_ind_split_idx:]) |
285 | | - |
286 | | - # We need to broadcast the parameters before applying the `*Subtensor*` |
287 | | - # with these indices, because the indices could be referencing broadcast |
288 | | - # dimensions that don't exist (yet) |
289 | | - bcast_dist_params = broadcast_params(dist_params, rv_op.ndims_params) |
290 | | - |
291 | | - # TODO: For multidimensional distributions, we need a map that tells us |
292 | | - # which dimensions of the parameters need to be indexed. |
293 | | - # |
294 | | - # For example, `multivariate_normal` would have the following: |
295 | | - # `RandomVariable.param_to_image_dims = ((0,), (0, 1))` |
296 | | - # |
297 | | - # I.e. the first parameter's (i.e. mean's) first dimension maps directly to |
298 | | - # the dimension of the RV's image, and its second parameter's |
299 | | - # (i.e. covariance's) first and second dimensions map directly to the |
300 | | - # dimension of the RV's image. |
301 | | - |
302 | | - args_lifted = tuple(p[ind_indices] for p in bcast_dist_params) |
303 | | - else: |
304 | | - # In this case, no indexing is applied to the parameters; only the |
305 | | - # `size` parameter is affected. |
306 | | - args_lifted = dist_params |
| 236 | + # Check that indexing does not act on support dims |
| 237 | + batched_ndims = rv.ndim - rv_op.ndim_supp |
| 238 | + if len(st_indices) > batched_ndims: |
| 239 | + # If the trailing indices are only dummy `slice(None)`s, we can discard them |
| 240 | + st_is_bool = st_is_bool[:batched_ndims] |
| 241 | + st_indices, supp_indices = ( |
| 242 | + st_indices[:batched_ndims], |
| 243 | + st_indices[batched_ndims:], |
| 244 | + ) |
| 245 | + for index in supp_indices: |
| 246 | + if not ( |
| 247 | + isinstance(index.type, SliceType) |
| 248 | + and all(NoneConst.equals(i) for i in index.owner.inputs) |
| 249 | + ): |
| 250 | + return False |
| 251 | + |
| 252 | + # If no one else is using the underlying `RandomVariable`, then we can |
| 253 | + # do this; otherwise, the graph would be internally inconsistent. |
| 254 | + if is_rv_used_in_graph(rv, node, fgraph): |
| 255 | + return False |
307 | 256 |
|
| 257 | + # Update the size to reflect the indexed dimensions |
308 | 258 | # TODO: Could use `ShapeFeature` info. We would need to be sure that |
309 | 259 | # `node` isn't in the results, though. |
310 | 260 | # if hasattr(fgraph, "shape_feature"): |
311 | 261 | # output_shape = fgraph.shape_feature.shape_of(node.outputs[0]) |
312 | 262 | # else: |
313 | | - output_shape = indexed_result_shape(base_rv.shape, st_indices) |
314 | | - |
315 | | - size_lifted = ( |
316 | | - output_shape if rv_op.ndim_supp == 0 else output_shape[: -rv_op.ndim_supp] |
| 263 | + output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices) |
| 264 | + new_size_ignoring_boolean = ( |
| 265 | + output_shape_ignoring_bool |
| 266 | + if rv_op.ndim_supp == 0 |
| 267 | + else output_shape_ignoring_bool[: -rv_op.ndim_supp] |
317 | 268 | ) |
318 | 269 |
|
319 | | - # Boolean indices can actually change the `size` value (compared to just |
320 | | - # *which* dimensions of `size` are used). |
| 270 | + # Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used). |
| 271 | + # The `indexed_result_shape` helper does not consider this |
321 | 272 | if any(st_is_bool): |
322 | | - size_lifted = tuple( |
| 273 | + new_size = tuple( |
323 | 274 | at_sum(idx) if is_bool else s |
324 | | - for s, is_bool, idx in zip( |
325 | | - size_lifted, st_is_bool, st_indices[: (reps_ind_split_idx + 1)] |
| 275 | + for s, is_bool, idx in zip_longest( |
| 276 | + new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False |
326 | 277 | ) |
327 | 278 | ) |
| 279 | + else: |
| 280 | + new_size = new_size_ignoring_boolean |
328 | 281 |
|
329 | | - new_node = rv_op.make_node(rng, size_lifted, dtype, *args_lifted) |
330 | | - _, new_rv = new_node.outputs |
| 282 | + # Update the parameters to reflect the indexed dimensions |
| 283 | + new_dist_params = [] |
| 284 | + for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params): |
| 285 | + # Apply indexing on the batched dimensions of the parameter |
| 286 | + batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp) |
| 287 | + batched_param = shape_padleft(param, batched_param_dims_missing) |
| 288 | + batched_st_indices = [] |
| 289 | + for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape): |
| 290 | + # A degenerate (length-1) dimension broadcasts, so indexing it at 0 always suffices |
| 291 | + if batched_param_shape == 1: |
| 292 | + batched_st_indices.append(0) |
| 293 | + else: |
| 294 | + batched_st_indices.append(st_index) |
| 295 | + new_dist_params.append(batched_param[tuple(batched_st_indices)]) |
| 296 | + |
| 297 | + # Create new RV |
| 298 | + new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params) |
| 299 | + new_rv = new_node.default_output() |
331 | 300 |
|
332 | | - # Calling `Op.make_node` directly circumvents test value computations, so |
333 | | - # we need to compute the test values manually |
334 | 301 | if config.compute_test_value != "off": |
335 | 302 | compute_test_value(new_node) |
336 | 303 |
|
|
0 commit comments