diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 0a0c867f82..ccf9dd90de 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -22,6 +22,9 @@ from pytensor.tensor.var import TensorConstant, TensorVariable +ShapeValueType = Union[None, np.integer, int, Variable] + + def register_shape_c_code(type, code, version=()): """ Tell Shape Op how to generate C code for an PyTensor Type. @@ -541,9 +544,7 @@ def c_code_cache_version(self): def specify_shape( x: Union[np.ndarray, Number, Variable], - shape: Union[ - int, List[Union[int, Variable]], Tuple[Union[int, Variable]], Variable - ], + shape: Union[ShapeValueType, List[ShapeValueType], Tuple[ShapeValueType]], ): """Specify a fixed shape for a `Variable`.