|
7 | 7 | from pytensor.configdefaults import config |
8 | 8 | from pytensor.graph.basic import Variable |
9 | 9 | from pytensor.graph.fg import FunctionGraph |
| 10 | +from pytensor.graph.replace import clone_replace |
10 | 11 | from pytensor.graph.type import Type |
11 | 12 | from pytensor.misc.safe_asarray import _asarray |
12 | 13 | from pytensor.scalar.basic import ScalarConstant |
@@ -337,6 +338,21 @@ def test_more_shapes(self): |
337 | 338 | Reshape, |
338 | 339 | ) |
339 | 340 |
|
| 341 | + def test_rebuild(self): |
| 342 | + x = as_tensor_variable(50) |
| 343 | + i = vector("i") |
| 344 | + i_test = np.zeros((100,), dtype=config.floatX) |
| 345 | + y = reshape(i, (100 // x, x)) |
| 346 | + assert y.type.shape == (2, 50) |
| 347 | + assert tuple(y.shape.eval({i: i_test})) == (2, 50) |
| 348 | + assert y.eval({i: i_test}).shape == (2, 50) |
| 349 | + |
| 350 | + x_new = as_tensor_variable(25) |
| 351 | + y_new = clone_replace(y, {x: x_new}, rebuild_strict=False) |
| 352 | + assert y_new.type.shape == (4, 25) |
| 353 | + assert tuple(y_new.shape.eval({i: i_test})) == (4, 25) |
| 354 | + assert y_new.eval({i: i_test}).shape == (4, 25) |
| 355 | + |
340 | 356 |
|
341 | 357 | def test_shape_i_hash(): |
342 | 358 | assert isinstance(Shape_i(np.int64(1)).__hash__(), int) |
@@ -524,6 +540,22 @@ def test_specify_shape_in_grad(self): |
524 | 540 | z_grad = grad(z.sum(), wrt=x) |
525 | 541 | assert isinstance(z_grad.owner.op, SpecifyShape) |
526 | 542 |
|
| 543 | + def test_rebuild(self): |
| 544 | + x = as_tensor_variable(50) |
| 545 | + i = matrix("i") |
| 546 | + i_test = np.zeros((4, 50), dtype=config.floatX) |
| 547 | + y = specify_shape(i, (None, x)) |
| 548 | + assert y.type.shape == (None, 50) |
| 549 | + assert tuple(y.shape.eval({i: i_test})) == (4, 50) |
| 550 | + assert y.eval({i: i_test}).shape == (4, 50) |
| 551 | + |
| 552 | + x_new = as_tensor_variable(100) |
| 553 | + i_test = np.zeros((4, 100), dtype=config.floatX) |
| 554 | + y_new = clone_replace(y, {x: x_new}, rebuild_strict=False) |
| 555 | + assert y_new.type.shape == (None, 100) |
| 556 | + assert tuple(y_new.shape.eval({i: i_test})) == (4, 100) |
| 557 | + assert y_new.eval({i: i_test}).shape == (4, 100) |
| 558 | + |
527 | 559 |
|
528 | 560 | class TestSpecifyBroadcastable: |
529 | 561 | def test_basic(self): |
|
0 commit comments