@@ -1,7 +1,7 @@
 import warnings
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from functools import partial
-from typing import Literal
+from typing import Literal, cast
 
 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple  # type: ignore
@@ -15,7 +15,7 @@
 from pytensor.tensor import math as ptm
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
+from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector
 
 
 class MatrixPinv(Op):
@@ -597,6 +597,109 @@ def infer_shape(self, fgraph, node, shapes):
         else:
             return [s_shape]
 
+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        """
+        Reverse-mode gradient of the SVD function. Adapted from the autograd implementation here:
+        https://github.com/HIPS/autograd/blob/01eacff7a4f12e6f7aebde7c4cb4c1c2633f217d/autograd/numpy/linalg.py#L194
+
+        And the mxnet implementation described in ..[1]
+
+        References
+        ----------
+        .. [1] Seeger, Matthias, et al. "Auto-differentiating linear algebra." arXiv preprint arXiv:1710.08717 (2017).
+        """
+        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+
+        if not self.compute_uv:
+            # We need all the components of the SVD to compute the gradient of A even if we only use the singular values
+            # in the cost function.
+            U, s, VT = svd(A, full_matrices=False, compute_uv=True)
+
+            (ds,) = (cast(ptb.TensorVariable, x) for x in output_grads)
+            A_bar = (U.conj() * ds[..., None, :]) @ VT
+
+            return [A_bar]
+
+        elif self.full_matrices:
+            raise NotImplementedError(
+                "Gradient of svd not implemented for full_matrices=True"
+            )
+
+        else:
+            U, s, VT = (cast(ptb.TensorVariable, x) for x in outputs)
+
+            # Handle disconnected inputs
+            # If a user asked for all the matrices but then only used a subset in the cost function, the unused outputs
+            # will be DisconnectedType. We replace DisconnectedTypes with zero matrices of the correct shapes.
+            new_output_grads = []
+            is_disconnected = [
+                isinstance(x.type, DisconnectedType) for x in output_grads
+            ]
+            if all(is_disconnected):
+                return [DisconnectedType()()]
+
+            for disconnected, output_grad, output in zip(
+                is_disconnected, output_grads, outputs
+            ):
+                if disconnected:
+                    new_output_grads.append(ptb.zeros_like(output))
+                else:
+                    new_output_grads.append(output_grad)
+
+            (dU, ds, dVT) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
+
+            V = VT.T
+            dV = dVT.T
+
+            m, n = A.shape[-2:]
+
+            k = ptm.min((m, n))
+            eye = ptb.eye(k)
+
+            def h(t):
+                """
+                Approximation of s_i ** 2 - s_j ** 2, from .. [1].
+                Robust to identical singular values (singular matrix input), although
+                gradients are still wrong in this case.
+                """
+                eps = 1e-8
+
+                # sign(0) = 0 in pytensor, which defeats the whole purpose of this function
+                sign_t = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
+                return ptm.maximum(ptm.abs(t), eps) * sign_t
+
+            numer = ptb.ones((k, k)) - eye
+            denom = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
+            E = numer / denom
+
+            utgu = U.T @ dU
+            vtgv = VT @ dV
+
+            A_bar = (E * (utgu - utgu.conj().T)) * s[..., None, :]
+            A_bar = A_bar + eye * ds[..., :, None]
+            A_bar = A_bar + s[..., :, None] * (E * (vtgv - vtgv.conj().T))
+            A_bar = U.conj() @ A_bar @ VT
+
+            A_bar = ptb.switch(
+                ptm.eq(m, n),
+                A_bar,
+                ptb.switch(
+                    ptm.lt(m, n),
+                    A_bar
+                    + (
+                        U / s[..., None, :] @ dVT @ (ptb.eye(n) - V @ V.conj().T)
+                    ).conj(),
+                    A_bar
+                    + (V / s[..., None, :] @ dU.T @ (ptb.eye(m) - U @ U.conj().T)).T,
+                ),
+            )
+            return [A_bar]
+
 
 def svd(a, full_matrices: bool = True, compute_uv: bool = True):
     """
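
As a sanity check on the compute_uv=False branch added above: with ds set to ones, A_bar = (U.conj() * ds[..., None, :]) @ VT reduces to U @ VT, which is the known gradient of the sum of singular values for a matrix with distinct, nonzero singular values. The snippet below is a minimal sketch, not part of this diff; it assumes only the standard pytensor public API (pytensor.grad, pytensor.function, pytensor.tensor.nlinalg.svd) and uses numpy's SVD as the reference. Variable names such as A_val are illustrative.

    # Minimal sketch: gradient of sum of singular values should equal U @ VT.
    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import svd

    A = pt.matrix("A")
    s = svd(A, compute_uv=False)       # singular values only -> compute_uv=False branch of L_op
    g = pytensor.grad(s.sum(), A)      # reverse-mode gradient w.r.t. A
    f = pytensor.function([A], g)

    rng = np.random.default_rng(0)
    A_val = rng.normal(size=(4, 3)).astype(A.dtype)
    U_np, _, VT_np = np.linalg.svd(A_val, full_matrices=False)

    # d(sum(s)) / dA = U @ VT when the singular values are distinct and nonzero;
    # a random Gaussian matrix satisfies this almost surely.
    np.testing.assert_allclose(f(A_val), U_np @ VT_np, rtol=1e-4, atol=1e-6)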
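
The disconnected-output handling in the else branch can be exercised the same way: ask for all three SVD outputs but use only one of them in the cost. The gradients for the unused outputs then arrive as DisconnectedType and are replaced by zeros inside L_op. Another small sketch under the same assumptions as above (not part of the diff):

    # Sketch: cost depends only on U; the grads for s and VT are disconnected
    # and get zero-filled by the new L_op.
    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import svd

    A = pt.matrix("A")
    U, s, VT = svd(A, full_matrices=False, compute_uv=True)
    cost = U.sum()                     # s and VT do not appear in the cost
    g = pytensor.grad(cost, A)         # goes through the zero-filling branch
    f = pytensor.function([A], g)

    rng = np.random.default_rng(0)
    A_val = rng.normal(size=(4, 3)).astype(A.dtype)
    print(f(A_val).shape)              # (4, 3)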
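
Finally, the full compute_uv=True gradient (the E-matrix and the rectangular correction terms in the nested switch) can be spot-checked against central finite differences of the same compiled cost. This is again a sketch, not part of the diff: it assumes float64 precision and a generic, well-conditioned test matrix so that the SVD (and its sign convention) varies smoothly under the small perturbations.

    # Sketch: finite differences vs. the symbolic gradient for a cost using all
    # three SVD outputs (square case).
    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import svd

    A = pt.matrix("A", dtype="float64")
    U, s, VT = svd(A, full_matrices=False, compute_uv=True)
    cost = U.sum() + (s**2).sum() + VT.sum()
    f_cost = pytensor.function([A], cost)
    f_grad = pytensor.function([A], pytensor.grad(cost, A))

    rng = np.random.default_rng(1)
    A_val = rng.normal(size=(3, 3))

    eps = 1e-5
    num_grad = np.zeros_like(A_val)
    for i in range(3):
        for j in range(3):
            dA = np.zeros_like(A_val)
            dA[i, j] = eps
            # central difference approximation of d(cost)/dA[i, j]
            num_grad[i, j] = (f_cost(A_val + dA) - f_cost(A_val - dA)) / (2 * eps)

    np.testing.assert_allclose(f_grad(A_val), num_grad, rtol=1e-3, atol=1e-4)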