|
16 | 16 | from pytensor.gradient import grad, hessian |
17 | 17 | from pytensor.graph.basic import Apply |
18 | 18 | from pytensor.graph.op import Op |
| 19 | +from pytensor.graph.replace import clone_replace |
19 | 20 | from pytensor.misc.safe_asarray import _asarray |
20 | 21 | from pytensor.raise_op import Assert |
21 | 22 | from pytensor.scalar import autocast_float, autocast_float_as |
@@ -818,6 +819,20 @@ def test_full(self): |
818 | 819 | res = pytensor.function([], full_at, mode=self.mode)() |
819 | 820 | assert np.array_equal(res, np.full((2, 3), 3, dtype="int64")) |
820 | 821 |
|
| 822 | + @pytest.mark.parametrize("func", (at.zeros, at.empty)) |
| 823 | + def test_rebuild(self, func): |
| 824 | + x = vector(shape=(50,)) |
| 825 | + x_test = np.zeros((50,)) |
| 826 | + y = func(x.shape) |
| 827 | + assert y.shape.eval({x: x_test}) == (50,) |
| 828 | + assert y.eval({x: x_test}).shape == (50,) |
| 829 | + |
| 830 | + x_new = vector(shape=(100,)) |
| 831 | + x_new_test = np.zeros((100,)) |
| 832 | + y_new = clone_replace(y, {x: x_new}, rebuild_strict=False) |
| 833 | + assert y_new.shape.eval({x_new: x_new_test}) == (100,) |
| 834 | + assert y_new.eval({x_new: x_new_test}).shape == (100,) |
| 835 | + |
821 | 836 |
|
822 | 837 | def test_infer_shape(): |
823 | 838 | with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"): |
|
0 commit comments