     simplify_mul,
 )
 from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.type import (
     TensorType,
     cmatrix,
@@ -4745,3 +4746,121 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
         out.eval({a: a_test, b: b_test}, mode=test_mode),
         rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
     )
+
+
+@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
+@pytest.mark.parametrize(
+    "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
+)
+@pytest.mark.parametrize(
+    "batch_other", [True, False], ids=["batched_other", "unbatched_other"]
+)
+def test_local_block_diag_dot_to_dot_block_diag(
+    left_multiply, batch_blockdiag, batch_other
+):
+    """
+    Test that dot(block_diag(x, y), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:])),
+    where n is the number of columns of x.
+    """
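+    # Editor's note on the identity under test: the off-diagonal blocks of
+    # block_diag(x, y) are zero, so each diagonal block only ever touches its
+    # own slice of z, and the single big dot splits into one dot per block.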
+
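+    # Helper: walk the ancestors of `graph` and flag any BlockDiagonal op,
+    # whether it appears directly or wrapped in Blockwise (the batched case).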
+    def has_blockdiag(graph):
+        return any(
+            (
+                var.owner
+                and (
+                    isinstance(var.owner.op, BlockDiagonal)
+                    or (
+                        isinstance(var.owner.op, Blockwise)
+                        and isinstance(var.owner.op.core_op, BlockDiagonal)
+                    )
+                )
+            )
+            for var in ancestors([graph])
+        )
+
+    a = tensor("a", shape=(4, 2))
+    b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
+    c = tensor("c", shape=(4, 4))
+    x = pt.linalg.block_diag(a, b, c)
+
+    d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
+
+    # Multiply from the left or the right, depending on the parametrization
+    if left_multiply:
+        out = x @ d
+    else:
+        out = d @ x
+
+    assert has_blockdiag(out)
+    fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
+    assert not has_blockdiag(fn.maker.fgraph.outputs[0])
+
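+    # After the rewrite each of the three diagonal blocks gets its own dot, so
+    # the compiled graph should contain exactly three (possibly
+    # Blockwise-wrapped) Dot/Dot22 nodes.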
+    n_dots_rewrite = sum(
+        isinstance(node.op, Dot | Dot22)
+        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
+        for node in fn.maker.fgraph.apply_nodes
+    )
+    assert n_dots_rewrite == 3
+
+    fn_expected = pytensor.function(
+        [a, b, c, d],
+        out,
+        mode=Mode(linker="py", optimizer=None),
+    )
+    assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
+
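+    # Without the rewrite there is a single dot against the full
+    # block-diagonal matrix.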
+    n_dots_no_rewrite = sum(
+        isinstance(node.op, Dot | Dot22)
+        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
+        for node in fn_expected.maker.fgraph.apply_nodes
+    )
+    assert n_dots_no_rewrite == 1
+
+    rng = np.random.default_rng()
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    rewrite_out = fn(a_val, b_val, c_val, d_val)
+    expected_out = fn_expected(a_val, b_val, c_val, d_val)
+    np.testing.assert_allclose(
+        rewrite_out,
+        expected_out,
+        atol=1e-6 if config.floatX == "float32" else 1e-12,
+        rtol=1e-6 if config.floatX == "float32" else 1e-12,
+    )
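+
+# Editor's sketch (not part of this patch): the rewrite rests on the identity
+# block_diag(A, B) @ D == concatenate([A @ D[:k], B @ D[k:]]) with k = A.shape[1].
+# A quick NumPy check of that identity, using hypothetical small shapes:
+#
+#     import numpy as np
+#     from scipy.linalg import block_diag
+#
+#     A, B = np.ones((2, 3)), np.eye(4)
+#     D = np.arange(7.0).reshape(7, 1)
+#     lhs = block_diag(A, B) @ D
+#     rhs = np.concatenate([A @ D[:3], B @ D[3:]])
+#     assert np.allclose(lhs, rhs)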
+
+
+@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
+@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
+def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
+    rng = np.random.default_rng()
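+    # Split `size` into three random block sizes that sum to `size`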
+    a_size = int(rng.uniform(0, size))
+    b_size = int(rng.uniform(0, size - a_size))
+    c_size = size - a_size - b_size
+
+    a = tensor("a", shape=(a_size, a_size))
+    b = tensor("b", shape=(b_size, b_size))
+    c = tensor("c", shape=(c_size, c_size))
+    d = tensor("d", shape=(size,))
+
+    x = pt.linalg.block_diag(a, b, c)
+    out = x @ d
+
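+    # Compile with and without the local_block_diag_dot_to_dot_block_diag
+    # rewrite so the benchmark compares the two graphs directly.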
+    mode = get_default_mode()
+    if not rewrite:
+        mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
+    fn = pytensor.function([a, b, c, d], out, mode=mode)
+
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    benchmark(
+        fn,
+        a_val,
+        b_val,
+        c_val,
+        d_val,
+    )