Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,9 @@ def default_output(self):
if len(self.outputs) == 1:
return self.outputs[0]
else:
raise ValueError(f"{self.op}.default_output should be an output index.")
elif not isinstance(do, int):
raise ValueError(f"{self.op}.default_output should be an int or long")
elif do < 0 or do >= len(self.outputs):
raise ValueError(f"{self.op}.default_output is out of range.")
Comment on lines -198 to -200
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Python error messages do a good enough job here when indexing is not valid or out of bounds. No need to reinvent the wheel.

raise ValueError(
f"Multi-output Op {self.op} default_output not specified"
)
return self.outputs[do]

def __str__(self):
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _as_tensor_Apply(x, name, ndim, **kwargs):
# use Apply's default output mechanism
if (x.op.default_output is None) and (len(x.outputs) != 1):
raise TypeError(
"Multi-output Op encountered. "
"Multi-output Op without default_output encountered. "
"Retry using only one of the outputs directly."
)

Expand Down
31 changes: 21 additions & 10 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,12 +500,13 @@ def test_infer_shape(self):


class ApplyDefaultTestOp(Op):
def __init__(self, id):
def __init__(self, id, n_outs=1):
self.default_output = id
self.n_outs = n_outs

def make_node(self, x):
x = at.as_tensor_variable(x)
return Apply(self, [x], [x.type()])
return Apply(self, [x], [x.type() for _ in range(self.n_outs)])

def perform(self, *args, **kwargs):
raise NotImplementedError()
Expand Down Expand Up @@ -556,16 +557,26 @@ def test_tensor_from_scalar(self):
y = as_tensor_variable(aes.int8())
assert isinstance(y.owner.op, TensorFromScalar)

def test_multi_outputs(self):
good_apply_var = ApplyDefaultTestOp(0).make_node(self.x)
as_tensor_variable(good_apply_var)
def test_default_output(self):
good_apply_var = ApplyDefaultTestOp(0, n_outs=1).make_node(self.x)
as_tensor_variable(good_apply_var) is good_apply_var

bad_apply_var = ApplyDefaultTestOp(-1).make_node(self.x)
with pytest.raises(ValueError):
good_apply_var = ApplyDefaultTestOp(-1, n_outs=1).make_node(self.x)
as_tensor_variable(good_apply_var) is good_apply_var

bad_apply_var = ApplyDefaultTestOp(1, n_outs=1).make_node(self.x)
with pytest.raises(IndexError):
_ = as_tensor_variable(bad_apply_var)

bad_apply_var = ApplyDefaultTestOp(2).make_node(self.x)
with pytest.raises(ValueError):
bad_apply_var = ApplyDefaultTestOp(2.0, n_outs=1).make_node(self.x)
with pytest.raises(TypeError):
_ = as_tensor_variable(bad_apply_var)

good_apply_var = ApplyDefaultTestOp(1, n_outs=2).make_node(self.x)
as_tensor_variable(good_apply_var) is good_apply_var.outputs[1]

bad_apply_var = ApplyDefaultTestOp(None, n_outs=2).make_node(self.x)
with pytest.raises(TypeError, match="Multi-output Op without default_output"):
_ = as_tensor_variable(bad_apply_var)

def test_list(self):
Expand All @@ -578,7 +589,7 @@ def test_list(self):
_ = as_tensor_variable(y)

bad_apply_var = ApplyDefaultTestOp([0, 1]).make_node(self.x)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
as_tensor_variable(bad_apply_var)

def test_ndim_strip_leading_broadcastable(self):
Expand Down