|
7 | 7 | from pytensor.configdefaults import config
|
8 | 8 | from pytensor.graph.basic import Variable
|
9 | 9 | from pytensor.graph.fg import FunctionGraph
|
| 10 | +from pytensor.graph.replace import clone_replace |
10 | 11 | from pytensor.graph.type import Type
|
11 | 12 | from pytensor.misc.safe_asarray import _asarray
|
12 | 13 | from pytensor.scalar.basic import ScalarConstant
|
@@ -337,6 +338,21 @@ def test_more_shapes(self):
|
337 | 338 | Reshape,
|
338 | 339 | )
|
339 | 340 |
|
| 341 | + def test_rebuild(self): |
| 342 | + x = as_tensor_variable(50) |
| 343 | + i = vector("i") |
| 344 | + i_test = np.zeros((100,), dtype=config.floatX) |
| 345 | + y = reshape(i, (100 // x, x)) |
| 346 | + assert y.type.shape == (2, 50) |
| 347 | + assert tuple(y.shape.eval({i: i_test})) == (2, 50) |
| 348 | + assert y.eval({i: i_test}).shape == (2, 50) |
| 349 | + |
| 350 | + x_new = as_tensor_variable(25) |
| 351 | + y_new = clone_replace(y, {x: x_new}, rebuild_strict=False) |
| 352 | + assert y_new.type.shape == (4, 25) |
| 353 | + assert tuple(y_new.shape.eval({i: i_test})) == (4, 25) |
| 354 | + assert y_new.eval({i: i_test}).shape == (4, 25) |
| 355 | + |
340 | 356 |
|
341 | 357 | def test_shape_i_hash():
|
342 | 358 | assert isinstance(Shape_i(np.int64(1)).__hash__(), int)
|
@@ -524,6 +540,22 @@ def test_specify_shape_in_grad(self):
|
524 | 540 | z_grad = grad(z.sum(), wrt=x)
|
525 | 541 | assert isinstance(z_grad.owner.op, SpecifyShape)
|
526 | 542 |
|
| 543 | + def test_rebuild(self): |
| 544 | + x = as_tensor_variable(50) |
| 545 | + i = matrix("i") |
| 546 | + i_test = np.zeros((4, 50), dtype=config.floatX) |
| 547 | + y = specify_shape(i, (None, x)) |
| 548 | + assert y.type.shape == (None, 50) |
| 549 | + assert tuple(y.shape.eval({i: i_test})) == (4, 50) |
| 550 | + assert y.eval({i: i_test}).shape == (4, 50) |
| 551 | + |
| 552 | + x_new = as_tensor_variable(100) |
| 553 | + i_test = np.zeros((4, 100), dtype=config.floatX) |
| 554 | + y_new = clone_replace(y, {x: x_new}, rebuild_strict=False) |
| 555 | + assert y_new.type.shape == (None, 100) |
| 556 | + assert tuple(y_new.shape.eval({i: i_test})) == (4, 100) |
| 557 | + assert y_new.eval({i: i_test}).shape == (4, 100) |
| 558 | + |
527 | 559 |
|
528 | 560 | class TestSpecifyBroadcastable:
|
529 | 561 | def test_basic(self):
|
|
0 commit comments