@@ -3376,6 +3376,7 @@ def inverse_permutation(perm):
33763376 )
33773377
33783378
3379+ # TODO: optimization to insert ExtractDiag with view=True
33793380class ExtractDiag (Op ):
33803381 """
33813382 Return specified diagonals.
@@ -3526,8 +3527,12 @@ def __setstate__(self, state):
35263527 self .axis2 = 1
35273528
35283529
3529- extract_diag = ExtractDiag ()
3530- # TODO: optimization to insert ExtractDiag with view=True
3530+ def extract_diag (x ):
3531+ warnings .warn (
3532+ "pytensor.tensor.extract_diag is deprecated. Use pytensor.tensor.diagonal instead." ,
3533+ FutureWarning ,
3534+ )
3535+ return diagonal (x )
35313536
35323537
35333538def diagonal (a , offset = 0 , axis1 = 0 , axis2 = 1 ):
@@ -3554,6 +3559,15 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
35543559 return ExtractDiag (offset , axis1 , axis2 )(a )
35553560
35563561
3562+ def trace (a , offset = 0 , axis1 = 0 , axis2 = 1 ):
3563+ """
3564+ Returns the sum along diagonals of the array.
3565+
3566+ Equivalent to `numpy.trace`
3567+ """
3568+ return diagonal (a , offset = offset , axis1 = axis1 , axis2 = axis2 ).sum (- 1 )
3569+
3570+
35573571class AllocDiag (Op ):
35583572 """An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
35593573
@@ -4254,6 +4268,7 @@ def take_along_axis(arr, indices, axis=0):
42544268 "full_like" ,
42554269 "empty" ,
42564270 "empty_like" ,
4271+ "trace" ,
42574272 "tril_indices" ,
42584273 "tril_indices_from" ,
42594274 "triu_indices" ,
0 commit comments