@@ -3481,20 +3481,18 @@ class PermuteRowElements(Op):
34813481 permutation instead.
34823482 """
34833483
3484- __props__ = ()
3484+ __props__ = ("inverse" ,)
3485+
3486+ def __init__ (self , inverse : bool ):
3487+ super ().__init__ ()
3488+ self .inverse = inverse
34853489
3486- def make_node (self , x , y , inverse ):
3490+ def make_node (self , x , y ):
34873491 x = as_tensor_variable (x )
34883492 y = as_tensor_variable (y )
3489- if inverse : # as_tensor_variable does not accept booleans
3490- inverse = as_tensor_variable (1 )
3491- else :
3492- inverse = as_tensor_variable (0 )
34933493
34943494 # y should contain integers
34953495 assert y .type .dtype in integer_dtypes
3496- # Inverse should be an integer scalar
3497- assert inverse .type .ndim == 0 and inverse .type .dtype in integer_dtypes
34983496
34993497 # Match shapes of x and y
35003498 x_dim = x .type .ndim
@@ -3511,7 +3509,7 @@ def make_node(self, x, y, inverse):
35113509 ]
35123510 out_type = tensor (dtype = x .type .dtype , shape = out_shape )
35133511
3514- inputlist = [x , y , inverse ]
3512+ inputlist = [x , y ]
35153513 outputlist = [out_type ]
35163514 return Apply (self , inputlist , outputlist )
35173515
@@ -3564,7 +3562,7 @@ def _rec_perform(self, node, x, y, inverse, out, curdim):
35643562 raise ValueError (f"Dimension mismatch: { xs0 } , { ys0 } " )
35653563
35663564 def perform (self , node , inp , out ):
3567- x , y , inverse = inp
3565+ x , y = inp
35683566 (outs ,) = out
35693567 x_s = x .shape
35703568 y_s = y .shape
@@ -3587,7 +3585,7 @@ def perform(self, node, inp, out):
35873585 if outs [0 ] is None or outs [0 ].shape != out_s :
35883586 outs [0 ] = np .empty (out_s , dtype = x .dtype )
35893587
3590- self ._rec_perform (node , x , y , inverse , outs [0 ], curdim = 0 )
3588+ self ._rec_perform (node , x , y , self . inverse , outs [0 ], curdim = 0 )
35913589
35923590 def infer_shape (self , fgraph , node , in_shapes ):
35933591 from pytensor .tensor .math import maximum
@@ -3599,14 +3597,14 @@ def infer_shape(self, fgraph, node, in_shapes):
35993597 return [out_shape ]
36003598
36013599 def grad (self , inp , grads ):
3602- from pytensor .tensor .math import Sum , eq
3600+ from pytensor .tensor .math import Sum
36033601
3604- x , y , inverse = inp
3602+ x , y = inp
36053603 (gz ,) = grads
36063604 # First, compute the gradient wrt the broadcasted x.
36073605 # If 'inverse' is False (0), apply the inverse of y on gz.
36083606 # Else, apply y on gz.
3609- gx = permute_row_elements (gz , y , eq ( inverse , 0 ) )
3607+ gx = permute_row_elements (gz , y , not self . inverse )
36103608
36113609 # If x has been broadcasted along some axes, we need to sum
36123610 # the gradient over these axes, but keep the dimension (as
@@ -3643,20 +3641,17 @@ def grad(self, inp, grads):
36433641 if x .type .dtype in discrete_dtypes :
36443642 gx = x .zeros_like ()
36453643
3646- # The elements of y and of inverse both affect the output,
3644+ # The elements of y affect the output,
36473645 # so they are connected to the output,
36483646 # and the transformation isn't defined if their values
36493647 # are non-integer, so the gradient with respect to them is
36503648 # undefined
36513649
3652- return [gx , grad_undefined (self , 1 , y ), grad_undefined (self , 1 , inverse )]
3653-
3654-
3655- _permute_row_elements = PermuteRowElements ()
3650+ return [gx , grad_undefined (self , 1 , y )]
36563651
36573652
3658- def permute_row_elements (x , y , inverse = 0 ):
3659- return _permute_row_elements ( x , y , inverse )
3653+ def permute_row_elements (x , y , inverse = False ):
3654+ return PermuteRowElements ( inverse = inverse )( x , y )
36603655
36613656
36623657def inverse_permutation (perm ):
0 commit comments