Skip to content

Commit 5da3d0c

Browse files
committed
Allow specializing shape of predefined tensors types
1 parent 3c66aa6 commit 5da3d0c

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

pytensor/tensor/type.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,18 @@ def parse_bcast_and_shape(s):
123123
self.name = name
124124
self.numpy_dtype = np.dtype(self.dtype)
125125

126+
def __call__(self, *args, shape=None, **kwargs):
127+
if shape is not None:
128+
# Check if shape is compatible with the original type
129+
new_type = self.clone(shape=shape)
130+
if self.is_super(new_type):
131+
return new_type(*args, **kwargs)
132+
else:
133+
raise ValueError(
134+
f"{shape=} is incompatible with original type shape {self.shape=}"
135+
)
136+
return super().__call__(*args, **kwargs)
137+
126138
def clone(
127139
self, dtype=None, shape=None, broadcastable=None, **kwargs
128140
) -> "TensorType":

tests/tensor/test_type.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
from pytensor.tensor.type import (
1111
TensorType,
1212
col,
13+
dmatrix,
14+
drow,
15+
fmatrix,
16+
frow,
1317
matrix,
1418
row,
1519
scalar,
@@ -477,3 +481,19 @@ def test_row_matrix_creator_helpers(helper):
477481
match = "The second dimension of a `col` must have shape 1, got 5"
478482
with pytest.raises(ValueError, match=match):
479483
helper(shape=(2, 5))
484+
485+
486+
def test_shape_of_predefined_dtype_tensor():
487+
# matrix can be converted to a row
488+
assert fmatrix(shape=(1, None)).type == frow
489+
490+
# row can be specialized
491+
assert drow(shape=(1, 5)).type == dmatrix(shape=(1, 5)).type
492+
493+
with pytest.raises(ValueError):
494+
# matrix cannot be converted into a tensor3
495+
fmatrix(shape=(None, None, None))
496+
497+
with pytest.raises(ValueError):
498+
# row can't be converted to a generic matrix
499+
drow(shape=(None, None))

0 commit comments

Comments
 (0)