 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import broadcast_params
-from pytensor.tensor.shape import Shape, Shape_i
+from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor,
     AdvancedSubtensor1,
@@ -115,23 +115,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
 
     For example, ``normal(mu, std).T == normal(mu.T, std.T)``.
 
-    The basic idea behind this rewrite is that we need to separate the
-    ``DimShuffle``-ing into distinct ``DimShuffle``s that each occur in two
-    distinct sub-spaces: the (set of independent) parameters and ``size``
-    (i.e. replications) sub-spaces.
-
-    If a ``DimShuffle`` exchanges dimensions across those two sub-spaces, then we
-    don't do anything.
-
-    Otherwise, if the ``DimShuffle`` only exchanges dimensions within each of
-    those sub-spaces, we can break it apart and apply the parameter-space
-    ``DimShuffle`` to the distribution parameters, and then apply the
-    replications-space ``DimShuffle`` to the ``size`` tuple. The latter is a
-    particularly simple rearranging of a tuple, but the former requires a
-    little more work.
-
-    TODO: Currently, multivariate support for this rewrite is disabled.
+    This rewrite is only applicable when the Dimshuffle operation does
+    not affect support dimensions.
 
+    TODO: Support dimension dropping
     """
 
     ds_op = node.op
@@ -142,128 +129,67 @@ def local_dimshuffle_rv_lift(fgraph, node):
     base_rv = node.inputs[0]
     rv_node = base_rv.owner
 
-    if not (
-        rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0
-    ):
+    if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False
 
-    # If no one else is using the underlying `RandomVariable`, then we can
-    # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(base_rv, node, fgraph):
+    # DimShuffles that drop dimensions are not supported yet
+    if ds_op.drop:
         return False
 
     rv_op = rv_node.op
     rng, size, dtype, *dist_params = rv_node.inputs
+    rv = rv_node.default_output()
 
-    # We need to know the dimensions that were *not* added by the `size`
-    # parameter (i.e. the dimensions corresponding to independent variates with
-    # different parameter values)
-    num_ind_dims = None
-    if len(dist_params) == 1:
-        num_ind_dims = dist_params[0].ndim
-    else:
-        # When there is more than one distribution parameter, assume that all
-        # of them will broadcast to the maximum number of dimensions
-        num_ind_dims = max(d.ndim for d in dist_params)
-
-    # If the indices in `ds_new_order` are entirely within the replication
-    # indices group or the independent variates indices group, then we can apply
-    # this rewrite.
-
-    ds_new_order = ds_op.new_order
-    # Create a map from old index order to new/`DimShuffled` index order
-    dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)]
-
-    # Find the index at which the replications/independents split occurs
-    reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp)
-
-    ds_reps_new_dims = dim_orders[:reps_ind_split_idx]
-    ds_ind_new_dims = dim_orders[reps_ind_split_idx:]
-    ds_in_ind_space = ds_ind_new_dims and all(
-        d >= reps_ind_split_idx for n, d in ds_ind_new_dims
-    )
+    # Check that Dimshuffle does not affect support dims
+    supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
+    shuffled_dims = {dim for i, dim in enumerate(ds_op.shuffle) if dim != i}
+    augmented_dims = set(d - rv_op.ndim_supp for d in ds_op.augment)
+    if (shuffled_dims | augmented_dims) & supp_dims:
+        return False
 
-    if ds_in_ind_space or (not ds_ind_new_dims and not ds_reps_new_dims):
+    # If no one else is using the underlying RandomVariable, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(base_rv, node, fgraph):
+        return False
 
-        # Update the `size` array to reflect the `DimShuffle`d dimensions,
-        # since the trailing dimensions in `size` represent the independent
-        # variates dimensions (for univariate distributions, at least)
-        has_size = get_vector_length(size) > 0
-        new_size = (
-            [constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order]
-            if has_size
-            else size
+    batched_dims = rv.ndim - rv_op.ndim_supp
+    batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)
+
+    # Make size explicit
+    missing_size_dims = batched_dims - get_vector_length(size)
+    if missing_size_dims > 0:
+        full_size = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
+        size = full_size[:missing_size_dims] + tuple(size)
+
+    # Update the size to reflect the DimShuffled dimensions
+    new_size = [
+        constant(1, dtype="int64") if o == "x" else size[o]
+        for o in batched_dims_ds_order
+    ]
+
+    # Update the params to reflect the DimShuffled dimensions
+    new_dist_params = []
+    for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
+        # Add broadcastable dimensions to the parameters that would have been expanded by the size
+        padleft = batched_dims - (param.ndim - param_ndim_supp)
+        if padleft > 0:
+            param = shape_padleft(param, padleft)
+
+        # Append the parameter's support dimension indexes to the batched-dims DimShuffle order
+        param_new_order = batched_dims_ds_order + tuple(
+            range(batched_dims, batched_dims + param_ndim_supp)
         )
+        new_dist_params.append(param.dimshuffle(param_new_order))
 
-        # Compute the new axes parameter(s) for the `DimShuffle` that will be
-        # applied to the `RandomVariable` parameters (they need to be offset)
-        if ds_ind_new_dims:
-            rv_params_new_order = [
-                d - reps_ind_split_idx if isinstance(d, int) else d
-                for d in ds_new_order[ds_ind_new_dims[0][0] :]
-            ]
-
-            if not has_size and len(ds_new_order[: ds_ind_new_dims[0][0]]) > 0:
-                # Additional broadcast dimensions need to be added to the
-                # independent dimensions (i.e. parameters), since there's no
-                # `size` to which they can be added
-                rv_params_new_order = (
-                    list(ds_new_order[: ds_ind_new_dims[0][0]]) + rv_params_new_order
-                )
-        else:
-            # This case is reached when, for example, `ds_new_order` only
-            # consists of new broadcastable dimensions (i.e. `"x"`s)
-            rv_params_new_order = ds_new_order
-
-        # Lift the `DimShuffle`s into the parameters
-        # NOTE: The parameters might not be broadcasted against each other, so
-        # we can only apply the parts of the `DimShuffle` that are relevant.
-        new_dist_params = []
-        for d in dist_params:
-            if d.ndim < len(ds_ind_new_dims):
-                _rv_params_new_order = [
-                    o
-                    for o in rv_params_new_order
-                    if (isinstance(o, int) and o < d.ndim) or o == "x"
-                ]
-            else:
-                _rv_params_new_order = rv_params_new_order
-
-            new_dist_params.append(
-                type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d)
-            )
-        new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
-
-        if config.compute_test_value != "off":
-            compute_test_value(new_node)
-
-        out = new_node.outputs[1]
-        if base_rv.name:
-            out.name = f"{base_rv.name}_lifted"
-        return [out]
+    new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
 
-    ds_in_reps_space = ds_reps_new_dims and all(
-        d < reps_ind_split_idx for n, d in ds_reps_new_dims
-    )
-
-    if ds_in_reps_space:
-        # Update the `size` array to reflect the `DimShuffle`d dimensions.
-        # There should be no need to `DimShuffle` now.
-        new_size = [
-            constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order
-        ]
-
-        new_node = rv_op.make_node(rng, new_size, dtype, *dist_params)
-
-        if config.compute_test_value != "off":
-            compute_test_value(new_node)
-
-        out = new_node.outputs[1]
-        if base_rv.name:
-            out.name = f"{base_rv.name}_lifted"
-        return [out]
+    if config.compute_test_value != "off":
+        compute_test_value(new_node)
 
-    return False
+    out = new_node.outputs[1]
+    if base_rv.name:
+        out.name = f"{base_rv.name}_lifted"
+    return [out]
 
 
 @node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
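
As a concrete illustration of the identity in the docstring, here is a minimal usage sketch of the kind of graph this rewrite transforms (assuming only the public `pytensor.tensor` API; the variable names are illustrative):

import pytensor.tensor as pt

mu = pt.matrix("mu")
std = pt.matrix("std")

# A DimShuffle (here, a transpose of the batch dimensions) applied to the
# output of a univariate RandomVariable
x = pt.random.normal(mu, std).T

# After the lift, the graph draws with the dimshuffled parameters directly,
# i.e. it is equivalent in distribution to pt.random.normal(mu.T, std.T)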
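
The support-dimension check and the size/parameter bookkeeping can be traced on a worked example. The following self-contained sketch (plain Python; the shapes and the vector-support RV are illustrative assumptions, and `shuffle`/`augment` are computed the way `DimShuffle` defines them) mirrors the new logic for an RV with one support dimension whose batch dimensions are transposed:

ndim_supp = 1          # e.g. a Dirichlet: one support dimension (assumed)
rv_ndim = 3            # batch shape (2, 3) plus a support dim of size 4 (assumed)
new_order = (1, 0, 2)  # swap the two batch dims, keep the support dim last

# DimShuffle bookkeeping, as the Op itself defines it
shuffle = [o for o in new_order if o != "x"]                # [1, 0, 2]
augment = [i for i, o in enumerate(new_order) if o == "x"]  # []

supp_dims = set(range(rv_ndim - ndim_supp, rv_ndim))                # {2}
shuffled_dims = {dim for i, dim in enumerate(shuffle) if dim != i}  # {0, 1}
augmented_dims = {d - ndim_supp for d in augment}                   # set()

# The rewrite applies: no shuffled or augmented dim touches a support dim
assert not (shuffled_dims | augmented_dims) & supp_dims

# Only the batch part of the order is applied to size and parameters;
# the support dims are re-appended at the end, untouched
batched_dims = rv_ndim - ndim_supp                                         # 2
batched_dims_ds_order = tuple(o for o in new_order if o not in supp_dims)  # (1, 0)

size = (2, 3)  # an explicit size covering the batch dims (assumed)
new_size = [1 if o == "x" else size[o] for o in batched_dims_ds_order]     # [3, 2]

param_new_order = batched_dims_ds_order + tuple(
    range(batched_dims, batched_dims + ndim_supp)
)  # (1, 0, 2): transpose the batch dims, keep the support dim in place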