1+ from collections .abc import Sequence
12from copy import copy
23from textwrap import dedent
4+ from typing import Literal
35
46import numpy as np
57from numpy .core .numeric import normalize_axis_tuple
@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):
5456
5557 Parameters
5658 ----------
57- input_broadcastable
58- The expected broadcastable pattern of the input
59+ input_ndim
60+ The expected number of dimensions of the input
5961 new_order
6062 A list representing the relationship between the input's
6163 dimensions and the output's dimensions. Each element of the
6264 list can either be an index or 'x'. Indices must be encoded
6365 as python integers, not pytensor symbolic integers.
64- inplace : bool, optional
65- If True (default), the output will be a view of the input.
66+ Missing indices correspond to dropped dimensions, which must have length 1.
6667
6768 Notes
6869 -----
@@ -77,50 +78,45 @@ class DimShuffle(ExternalCOp):
7778
7879 .. code-block:: python
7980
80- DimShuffle((False, False, False), ["x", 2, "x", 0, 1])
81+ DimShuffle(input_ndim=3, new_order= ["x", 2, "x", 0, 1])
8182
82- This `Op` will only work on 3d tensors with no broadcastable
83- dimensions. The first dimension will be broadcastable,
83+ This `Op` will only work on 3d tensors.
84+ The first dimension of the output will be broadcastable,
8485 then we will have the third dimension of the input tensor as
8586 the second of the resulting tensor, etc. If the tensor has
8687 shape (20, 30, 40), the resulting tensor will have dimensions
8788 (1, 40, 1, 20, 30). (AxBxC tensor is mapped to 1xCx1xAxB tensor)
8889
8990 .. code-block:: python
9091
91- DimShuffle((True, False), [1])
92+ DimShuffle(input_ndim=2, new_order= [1])
9293
93- This `Op` will only work on 2d tensors with the first dimension
94- broadcastable.
95- The second dimension of the input tensor will be the first dimension of
96- the resulting tensor.
97- If the tensor has shape (1, 20), the resulting tensor will have shape
98- (20, ).
94+ This `Op` will only work on 2d tensors with the first dimension broadcastable.
95+ The second dimension of the input tensor will be the first dimension of the resulting tensor.
96+ If the tensor has shape (1, 20), the resulting tensor will have shape (20, ).
9997
10098 Examples
10199 --------
102100 .. code-block:: python
103101
104- DimShuffle((), ["x"]) # make a 0d (scalar) into a 1d vector
105- DimShuffle((False, False), [0, 1]) # identity
106- DimShuffle((False, False), [1, 0]) # inverts the 1st and 2nd dimensions
107- DimShuffle((False,), ["x", 0]) # make a row out of a 1d vector
108- # (N to 1xN)
109- DimShuffle((False,), [0, "x"]) # make a column out of a 1d vector
110- # (N to Nx1)
111- DimShuffle((False, False, False), [2, 0, 1]) # AxBxC to CxAxB
112- DimShuffle((False, False), [0, "x", 1]) # AxB to Ax1xB
113- DimShuffle((False, False), [1, "x", 0]) # AxB to Bx1xA
114-
115- The reordering of the dimensions can be done with the numpy.transpose
116- function.
117- Adding, subtracting dimensions can be done with reshape.
102+ DimShuffle(input_ndim=0, new_order=["x"]) # make a 0d (scalar) into a 1d vector
103+ DimShuffle(input_ndim=2, new_order=[0, 1]) # identity
104+ DimShuffle(input_ndim=2, new_order=[1, 0]) # transposition
105+ DimShuffle(input_ndim=1, new_order=["x", 0]) # make a row out of a 1d vector (N to 1xN)
106+ DimShuffle(input_ndim=1, new_order=[0, "x"]) # make a column out of a 1d vector (N to Nx1)
107+ DimShuffle(input_ndim=3, new_order=[2, 0, 1]) # AxBxC to CxAxB
108+ DimShuffle(input_ndim=2, new_order=[0, "x", 1]) # AxB to Ax1xB
109+ DimShuffle(input_ndim=2, new_order=[1, "x", 0]) # AxB to Bx1xA
118110
111+ Notes
112+ -----
113+ The Python implementation of this Op combines numpy.transpose, for reordering the dimensions,
114+ and numpy.reshape, for removing and adding broadcastable dimensions.
119115 """
120116
121117 _f16_ok = True
122118 check_input = False
123- __props__ = ("input_broadcastable " , "new_order" , "inplace" )
119+ __props__ = ("input_ndim " , "new_order" , "inplace" )
124120 c_func_file = "c_code/dimshuffle.c"
125121 c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
126122
@@ -133,16 +129,14 @@ def params_type(self):
133129 inplace = scalar_bool ,
134130 )
135131
136- def __init__ (self , input_broadcastable , new_order ):
132+ def __init__ (self , * , input_ndim : int , new_order : Sequence [ int | Literal [ "x" ]] ):
137133 super ().__init__ ([self .c_func_file ], self .c_func_name )
138134
139- self .input_broadcastable = tuple (input_broadcastable )
140- if not all (isinstance (bs , bool | np .bool_ ) for bs in self .input_broadcastable ):
141- raise ValueError (
142- f"input_broadcastable must be boolean, { self .input_broadcastable } "
143- )
144- self .new_order = tuple (new_order )
135+ if not isinstance (input_ndim , int ):
136+ raise TypeError (f"input_ndim must be an integer, got { type (input_ndim )} " )
145137
138+ self .input_ndim = input_ndim
139+ self .new_order = tuple (new_order )
146140 self .inplace = True
147141
148142 for i , j in enumerate (new_order ):
@@ -152,10 +146,10 @@ def __init__(self, input_broadcastable, new_order):
152146 "DimShuffle indices must be Python ints; got "
153147 f"{ j } of type { type (j )} ."
154148 )
155- if j >= len ( input_broadcastable ) :
149+ if j >= input_ndim :
156150 raise ValueError (
157151 f"new_order[{ i } ] is { j } , but the input only has "
158- f"{ len ( input_broadcastable ) } axes."
152+ f"{ input_ndim } axes."
159153 )
160154 if j in new_order [(i + 1 ) :]:
161155 raise ValueError (
@@ -164,19 +158,7 @@ def __init__(self, input_broadcastable, new_order):
164158 )
165159
166160 # List of input dimensions to drop
167- drop = []
168- for i , b in enumerate (input_broadcastable ):
169- if i not in new_order :
170- # We want to drop this dimension because it's not a value in
171- # `new_order`
172- if b == 1 :
173- drop .append (i )
174- else :
175- # We cannot drop non-broadcastable dimensions
176- raise ValueError (
177- "Cannot drop a non-broadcastable dimension: "
178- f"{ input_broadcastable } , { new_order } "
179- )
161+ drop = [i for i in range (input_ndim ) if i not in new_order ]
180162
181163 # This is the list of the original dimensions that we keep
182164 self .shuffle = [x for x in new_order if x != "x" ]
@@ -186,7 +168,6 @@ def __init__(self, input_broadcastable, new_order):
186168 self .augment = sorted (i for i , x in enumerate (new_order ) if x == "x" )
187169 self .drop = drop
188170
189- input_ndim = len (input_broadcastable )
190171 self .is_left_expand_dims = self .augment and (
191172 input_ndim == 0 or new_order [- input_ndim :] == list (range (input_ndim ))
192173 )
@@ -204,30 +185,29 @@ def __setstate__(self, state):
204185 # Let's just build the ExternalCOp.
205186 super ().__init__ ([self .c_func_file ], self .c_func_name )
206187
207- def make_node (self , _input ):
208- input = as_tensor_variable (_input )
209- ib = tuple (s == 1 for s in input .type .shape )
210- if ib != self .input_broadcastable :
211- if len (ib ) != len (self .input_broadcastable ):
188+ def make_node (self , inp ):
189+ input = as_tensor_variable (inp )
190+ if input .type .ndim != self .input_ndim :
191+ raise TypeError (
192+ "The number of dimensions of the input is incorrect for this op. "
193+ f"Expected { self .input_ndim } , got { input .type .ndim } ."
194+ )
195+
196+ input_static_shape = input .type .shape
197+
198+ # Check the static shape for invalid drops; dimensions of unknown (None) length pass here
199+ for d in self .drop :
200+ if input_static_shape [d ] not in (1 , None ):
212201 raise TypeError (
213- "The number of dimensions of the "
214- f"input is incorrect for this op. Expected { self .input_broadcastable } , got { ib } ."
202+ f"Input dropped dimension { d } must have length 1 but has { input_static_shape [d ]} "
215203 )
216- for expected , b in zip (self .input_broadcastable , ib ):
217- if expected and not b :
218- raise TypeError (
219- "The broadcastable pattern of the "
220- f"input is incorrect for this op. Expected { self .input_broadcastable } , got { ib } ."
221- )
222- # else, expected == b or not expected and b
223- # Both case are good.
224204
225205 out_static_shape = []
226206 for dim_idx in self .new_order :
227207 if dim_idx == "x" :
228208 out_static_shape .append (1 )
229209 else :
230- out_static_shape .append (input . type . shape [dim_idx ])
210+ out_static_shape .append (input_static_shape [dim_idx ])
231211
232212 output = TensorType (dtype = input .type .dtype , shape = out_static_shape )()
233213
@@ -254,12 +234,14 @@ def perform(self, node, inp, out):
254234 if not isinstance (res , np .ndarray | np .memmap ):
255235 raise TypeError (res )
256236
237+ # Put dropped axes at end
257238 res = res .transpose (self .transposition )
258239
259- shape = list (res .shape [: len (self .shuffle )])
240+ # Define new shape without dropped axes and including the new ones
241+ new_shape = list (res .shape [: len (self .shuffle )])
260242 for augm in self .augment :
261- shape .insert (augm , 1 )
262- res = res .reshape (shape )
243+ new_shape .insert (augm , 1 )
244+ res = res .reshape (new_shape )
263245
264246 if not self .inplace :
265247 res = np .copy (res )
@@ -284,22 +266,15 @@ def R_op(self, inputs, eval_points):
284266 def grad (self , inp , grads ):
285267 (x ,) = inp
286268 (gz ,) = grads
287- gz = as_tensor_variable (gz )
288269 grad_order = ["x" ] * x .type .ndim
289270 for i , v in enumerate (self .new_order ):
290271 if v != "x" :
291272 grad_order [v ] = i
292- # Do not make the DimShuffle inplace as an optimization at the
293- # canonicalization optimization phase will remove the inplace.
294- # The inplace will be reintroduced automatically later in the graph.
295- if inp [0 ].dtype in discrete_dtypes :
296- return [inp [0 ].zeros_like (dtype = config .floatX )]
273+
274+ if x .type .dtype in discrete_dtypes :
275+ return [x .zeros_like (dtype = config .floatX )]
297276 else :
298- return [
299- DimShuffle (tuple (s == 1 for s in gz .type .shape ), grad_order )(
300- Elemwise (scalar_identity )(gz )
301- )
302- ]
277+ return [gz .dimshuffle (grad_order )]
303278
304279
305280class DimShufflePrinter (Printer ):
@@ -409,7 +384,7 @@ def __setstate__(self, d):
409384 self .nfunc = None
410385 self .inplace_pattern = frozendict (self .inplace_pattern )
411386
412- def get_output_info (self , dim_shuffle , * inputs ):
387+ def get_output_info (self , * inputs ):
413388 """Return the outputs dtype and broadcastable pattern and the
414389 dimshuffled inputs.
415390
@@ -427,12 +402,7 @@ def get_output_info(self, dim_shuffle, *inputs):
427402 if not difference :
428403 args .append (input )
429404 else :
430- args .append (
431- dim_shuffle (
432- input .type .broadcastable ,
433- ["x" ] * difference + list (range (length )),
434- )(input )
435- )
405+ args .append (input .dimshuffle (["x" ] * difference + list (range (length ))))
436406 inputs = args
437407
438408 # HERE: all the broadcast dims have the same length now
@@ -489,7 +459,7 @@ def make_node(self, *inputs):
489459 using DimShuffle.
490460 """
491461 inputs = [as_tensor_variable (i ) for i in inputs ]
492- out_dtypes , out_shapes , inputs = self .get_output_info (DimShuffle , * inputs )
462+ out_dtypes , out_shapes , inputs = self .get_output_info (* inputs )
493463 outputs = [
494464 TensorType (dtype = dtype , shape = shape )()
495465 for dtype , shape in zip (out_dtypes , out_shapes )
@@ -634,7 +604,7 @@ def transform(r):
634604 res = pytensor .tensor .basic .constant (
635605 np .asarray (r .data ), dtype = r .type .dtype
636606 )
637- return DimShuffle ((), ["x" ] * nd )( res )
607+ return res . dimshuffle ( ["x" ] * nd )
638608
639609 new_r = Elemwise (node .op , {})(* [transform (ipt ) for ipt in node .inputs ])
640610 if isinstance (new_r , list | tuple ):
@@ -1707,13 +1677,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Appl
17071677 batched_ndims = x .type .ndim - node .inputs [0 ].type .ndim
17081678 if not batched_ndims :
17091679 return node .op .make_node (x )
1710- input_broadcastable = x .type .broadcastable [:batched_ndims ] + op .input_broadcastable
1711- # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
1712- # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
1680+ # e.g., ds(input_ndim=2, order=(1, "x", 0)) -> ds(input_ndim=4, order=(0, 1, 3, "x", 2))
1681+ # e.g., ds(input_ndim=2, order=(1, "x")) -> ds(input_ndim=4, order=(0, 1, 3, "x"))
17131682 new_order = list (range (batched_ndims )) + [
17141683 "x" if (o == "x" ) else (o + batched_ndims ) for o in op .new_order
17151684 ]
1716- return DimShuffle ( input_broadcastable , new_order ).make_node ( x )
1685+ return x . dimshuffle ( new_order ).owner
17171686
17181687
17191688def get_normalized_batch_axes (
0 commit comments