|
6 | 6 |
|
7 | 7 | import pytensor |
8 | 8 | import tests.unittest_tools as utt |
| 9 | +from pytensor.compile import DeepCopyOp |
9 | 10 | from pytensor.compile.mode import get_default_mode |
10 | 11 | from pytensor.graph.basic import Constant, equal_computations |
11 | 12 | from pytensor.tensor import get_vector_length |
12 | 13 | from pytensor.tensor.basic import constant |
13 | 14 | from pytensor.tensor.elemwise import DimShuffle |
14 | 15 | from pytensor.tensor.math import dot, eq |
| 16 | +from pytensor.tensor.shape import Shape |
15 | 17 | from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor |
16 | 18 | from pytensor.tensor.type import ( |
17 | 19 | TensorType, |
@@ -245,8 +247,14 @@ def test__getitem__newaxis(x, indices, new_order): |
245 | 247 |
|
246 | 248 | def test_fixed_shape_variable_basic(): |
247 | 249 | x = TensorVariable(TensorType("int64", shape=(4,)), None) |
248 | | - assert isinstance(x.shape, Constant) |
249 | | - assert np.array_equal(x.shape.data, (4,)) |
| 250 | + assert x.type.shape == (4,) |
| 251 | + assert isinstance(x.shape.owner.op, Shape) |
| 252 | + |
| 253 | + shape_fn = pytensor.function([x], x.shape) |
| 254 | + opt_shape = shape_fn.maker.fgraph.outputs[0] |
| 255 | + assert isinstance(opt_shape.owner.op, DeepCopyOp) |
| 256 | + assert isinstance(opt_shape.owner.inputs[0], Constant) |
| 257 | + assert np.array_equal(opt_shape.owner.inputs[0].data, (4,)) |
250 | 258 |
|
251 | 259 | x = TensorConstant( |
252 | 260 | TensorType("int64", shape=(None, None)), np.array([[1, 2], [2, 3]]) |
|
0 commit comments