|  | 
| 3 | 3 | from typing import cast | 
| 4 | 4 | 
 | 
| 5 | 5 | from pytensor import Variable | 
|  | 6 | +from pytensor import tensor as pt | 
| 6 | 7 | from pytensor.graph import Apply, FunctionGraph | 
| 7 | 8 | from pytensor.graph.rewriting.basic import ( | 
| 8 | 9 |     copy_stack_trace, | 
| @@ -611,3 +612,96 @@ def rewrite_inv_inv(fgraph, node): | 
| 611 | 612 |     ): | 
| 612 | 613 |         return None | 
| 613 | 614 |     return [potential_inner_inv.inputs[0]] | 
|  | 615 | + | 
|  | 616 | + | 
|  | 617 | +@register_canonicalize | 
|  | 618 | +@register_stabilize | 
|  | 619 | +@node_rewriter([Blockwise]) | 
|  | 620 | +def rewrite_inv_eye_to_eye(fgraph, node): | 
|  | 621 | +    """ | 
|  | 622 | +     This rewrite takes advantage of the fact that the inverse of an identity matrix is the matrix itself | 
|  | 623 | +    The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside an inverse op. | 
|  | 624 | +    Parameters | 
|  | 625 | +    ---------- | 
|  | 626 | +    fgraph: FunctionGraph | 
|  | 627 | +        Function graph being optimized | 
|  | 628 | +    node: Apply | 
|  | 629 | +        Node of the function graph to be optimized | 
|  | 630 | +    Returns | 
|  | 631 | +    ------- | 
|  | 632 | +    list of Variable, optional | 
|  | 633 | +        List of optimized variables, or None if no optimization was performed | 
|  | 634 | +    """ | 
|  | 635 | +    valid_inverses = (MatrixInverse, MatrixPinv) | 
|  | 636 | +    core_op = node.op.core_op | 
|  | 637 | +    if not (isinstance(core_op, valid_inverses)): | 
|  | 638 | +        return None | 
|  | 639 | + | 
|  | 640 | +    # Check whether input to inverse is Eye and the 1's are on main diagonal | 
|  | 641 | +    eye_check = node.inputs[0] | 
|  | 642 | +    if not ( | 
|  | 643 | +        eye_check.owner | 
|  | 644 | +        and isinstance(eye_check.owner.op, Eye) | 
|  | 645 | +        and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0 | 
|  | 646 | +    ): | 
|  | 647 | +        return None | 
|  | 648 | +    return [eye_check] | 
|  | 649 | + | 
|  | 650 | + | 
|  | 651 | +@register_canonicalize | 
|  | 652 | +@register_stabilize | 
|  | 653 | +@node_rewriter([Blockwise]) | 
|  | 654 | +def rewrite_inv_diag_to_diag_reciprocal(fgraph, node): | 
|  | 655 | +    """ | 
|  | 656 | +     This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements. | 
|  | 657 | +     This function deals with diagonal matrix arising from the multiplicaton of eye with a scalar/vector/matrix | 
|  | 658 | +
 | 
|  | 659 | +    Parameters | 
|  | 660 | +    ---------- | 
|  | 661 | +    fgraph: FunctionGraph | 
|  | 662 | +        Function graph being optimized | 
|  | 663 | +    node: Apply | 
|  | 664 | +        Node of the function graph to be optimized | 
|  | 665 | +
 | 
|  | 666 | +    Returns | 
|  | 667 | +    ------- | 
|  | 668 | +    list of Variable, optional | 
|  | 669 | +        List of optimized variables, or None if no optimization was performed | 
|  | 670 | +    """ | 
|  | 671 | +    valid_inverses = (MatrixInverse, MatrixPinv) | 
|  | 672 | +    core_op = node.op.core_op | 
|  | 673 | +    if not (isinstance(core_op, valid_inverses)): | 
|  | 674 | +        return None | 
|  | 675 | + | 
|  | 676 | +    inputs = node.inputs[0] | 
|  | 677 | +    # Check for use of pt.diag first | 
|  | 678 | +    if ( | 
|  | 679 | +        inputs.owner | 
|  | 680 | +        and isinstance(inputs.owner.op, AllocDiag) | 
|  | 681 | +        and AllocDiag.is_offset_zero(inputs.owner) | 
|  | 682 | +    ): | 
|  | 683 | +        inv_input = inputs.owner.inputs[0] | 
|  | 684 | +        if inv_input.type.ndim == 1: | 
|  | 685 | +            inv_val = pt.diag(1 / inv_input) | 
|  | 686 | +            return [inv_val] | 
|  | 687 | + | 
|  | 688 | +    # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix | 
|  | 689 | +    inputs_or_none = _find_diag_from_eye_mul(inputs) | 
|  | 690 | +    if inputs_or_none is None: | 
|  | 691 | +        return None | 
|  | 692 | + | 
|  | 693 | +    eye_input, non_eye_inputs = inputs_or_none | 
|  | 694 | + | 
|  | 695 | +    # Dealing with only one other input | 
|  | 696 | +    if len(non_eye_inputs) != 1: | 
|  | 697 | +        return None | 
|  | 698 | + | 
|  | 699 | +    non_eye_input = non_eye_inputs[0] | 
|  | 700 | + | 
|  | 701 | +    # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those | 
|  | 702 | +    if non_eye_input.type.broadcastable[-2:] == (False, False): | 
|  | 703 | +        # For Matrix | 
|  | 704 | +        return [eye_input / non_eye_input.diagonal(axis1=-1, axis2=-2)] | 
|  | 705 | +    else: | 
|  | 706 | +        # For Vector or Scalar | 
|  | 707 | +        return [eye_input / non_eye_input] | 
0 commit comments