Fix typing in subtensor module #823
a9798f2
to
2b77374
Compare
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #823      +/-   ##
==========================================
- Coverage   80.99%   80.98%   -0.02%
==========================================
  Files         169      169
  Lines       46939    46961      +22
  Branches    11467    11478      +11
==========================================
+ Hits        38019    38032      +13
- Misses       6713     6714       +1
- Partials     2207     2215       +8
2b77374
to
6200b05
Compare
the coverage ❌ is due to the overload signatures - can't do anything about it
Had a look, have some doubts about a few of the changes, but the direction seems great.
In addition, if you want to remove the lines about sys.maxsize from the giant canonical slice function feel free to do it, as we are not using that for sure (and there's even an inline comment in one of its uses, saying so)
pytensor/tensor/subtensor.py
Outdated
    else:
        sslice = vlit
except NotScalarConstantError:
    raise ValueError(f"Slice {theslice} is not a supported slice type.")
Couldn't it be a slice with symbolic start/step/stop at this point?
From reading the code it seems you changed the meaning of the function to fail with certain symbolic inputs, such as slice(pt.lscalar("start"), pt.lscalar("stop"), pt.lscalar("step")) or a scalar indexing pt.lscalar("idx"), where it was accommodating them before
Okay this function does not expect symbolic slices (or at least it was never testing them), so the new changes shouldn't either?
If it's not a slice it must be a scalar index, and you shouldn't have to worry about the slice case showing up again here, the old code was certainly not worried about that case. If it should handle symbolic slices (i.e., SliceVariable) types, then we also need tests for those.
It wasn't testing them, but as_index_literal already took them apart into a native slice:
pytensor/pytensor/tensor/subtensor.py
Line 172 in efa845a
    idx = slice(*idx.owner.inputs)
Consequently, get_canonical_form_slice handled them already.
That still means this function only needs to handle python slices, and the inner code can be simplified
pytensor/tensor/subtensor.py
Outdated
if vlit is None:
    return slice(0, length, 1), 1
Are you sure? Isn't None a newaxis, which shouldn't correspond to any slice?
Also unclear whether this function is expected to handle None. Like SliceVariable, those only show up in the index for AdvancedSubtensor Ops. For basic Subtensor Ops, newaxis are eagerly converted to expand_dims, and never show as inputs of the Op
From the docstring I would say the function is not supposed to do None → None things.
There were no tests for that either, and we would create a bigger mess by hinting a possibly None return type.
Actually, I think we should raise ValueError in this situation.
Regarding the graphs being complicated, there's a related issue #112
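To illustrate why None is not interchangeable with slice(None) in an index, here is a minimal pure-Python sketch of NumPy-style basic indexing rules applied to shapes. The helper name indexed_shape is hypothetical and is not part of PyTensor or NumPy; it only models how each kind of index entry affects the output shape.

```python
def indexed_shape(shape: tuple, index: tuple) -> tuple:
    """Shape produced by NumPy-style basic indexing (illustrative sketch only)."""
    out = []
    dims = iter(shape)
    for item in index:
        if item is None:
            # newaxis: inserts a length-1 axis and consumes no input axis
            out.append(1)
        elif isinstance(item, slice):
            # a slice consumes one input axis and keeps it (possibly shortened)
            out.append(len(range(*item.indices(next(dims)))))
        elif isinstance(item, int):
            # an integer consumes one input axis and drops it
            next(dims)
        else:
            raise TypeError(f"unsupported index entry: {item!r}")
    # axes not mentioned in the index are kept unchanged
    out.extend(dims)
    return tuple(out)


print(indexed_shape((3, 4), (None,)))         # a new leading axis appears
print(indexed_shape((3, 4), (slice(None),)))  # the shape is unchanged
```

Since None changes the number of axes while a slice never does, treating None as slice(0, length, 1) would silently describe a different operation, which is the argument for raising ValueError instead.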
pytensor/tensor/subtensor.py
Outdated
value = theslice

value = switch(lt(value, 0), (value + length), value)
# Try to extract a scalar slice and return it already.
This branch may be cleaner if it clearly tries to not handle slices, by checking if not isinstance(theslice, slice | SliceVariable)?
I don't know if newaxis are supposed to make it to this function, I had a question below about it.
And I would call these scalar indexing instead of slices, not scalar slices?
Then after this branch you can force the SliceVariable to be a normal slice by doing as_index_literal or slice(*sslice.owner.inputs)
Edit: After reading the pre-existing code and tests I am not sure this function is supposed to handle SliceVariables, so ignore that part of the remark
It can do that, but I don't think we should "advertise" it in the type hints. It's covered by the Variable hint already.
tests/tensor/test_subtensor.py
Outdated
class TestGetCanonicalFormSlice:
    def test_none_constant(self):
Very unsure about this. None != slice(None) in indexing
Yes, I changed this into raising a ValueError, because addition of a new axis is something other than canonical slicing.
See comment above too.
tests/tensor/test_subtensor.py
Outdated
def test_symbolic_slice(self):
    idx = make_slice(slice(3, 7, 2))
    assert not isinstance(idx, slice)
    res = get_canonical_form_slice(idx, 10)
Nitpick suggestion for these kinds of tests
- res = get_canonical_form_slice(idx, 10)
+ canonical_slice, direction = get_canonical_form_slice(idx, 10)
Too scared of black's line wrapping wrath
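For context on the (canonical_slice, direction) pair unpacked in the suggestion above, here is a plain-Python sketch of the idea for integer slices and integer lengths only. It assumes the convention that the canonical slice has a positive step with 0 <= start <= stop <= length, and that direction is 1 or -1 depending on whether the selected elements must be reversed afterwards. The name canonical_form is hypothetical; this is not PyTensor's implementation, which also has to work with symbolic values.

```python
def canonical_form(theslice: slice, length: int) -> tuple:
    """Return (slice with positive step, direction) selecting the same elements."""
    # slice.indices clamps start/stop to the sequence and fills in defaults
    start, stop, step = theslice.indices(length)
    if step > 0:
        # make sure stop is never below start so the invariant holds
        return slice(start, max(start, stop), step), 1
    # Negative step: find the elements it visits, then describe them ascending
    count = len(range(start, stop, step))
    if count == 0:
        return slice(0, 0, 1), 1
    last = start + (count - 1) * step  # smallest index visited
    return slice(last, start + 1, -step), -1


data = list(range(10))
canon, direction = canonical_form(slice(None, None, -3), len(data))
# Applying the canonical slice and then the direction recovers the original selection
assert data[canon][::direction] == data[::-3]
```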
tests/tensor/test_subtensor.py
Outdated
    assert res[1] == 1

def test_symbolic_slice(self):
    idx = make_slice(slice(3, 7, 2))
This is the part I was confused about. I thought the function was meant to handle these types originally, but after more careful reading, I don't think it was?
If it was (or became) meant to handle them, then it should also handle make_slice with symbolic inputs.
No it can handle variable and constant symbolic slices too 😎
pytensor/tensor/subtensor.py
Outdated
@overload
def as_index_literal(idx: Constant) -> int | np.integer | None: ...
You can be more specific by using NoneConst, SliceConst, Tensor|ScalarConstant and SliceVariable, Tensor|ScalarVariable in the overloaded signatures
TensorConstants are already a subclass of TensorVariable (so you don't need that distinction at the signature level), but the other types ScalarConstant | SliceConstant may not be. You can make them if you want.
This is the PR where we made some of those subclasses: #628
Thanks for pointing out that these types exist. With them, I was able to simplify the function, and among other things avoid the getattr too.
I will not reply to your other comments on as_index_literal, since they no longer apply.
Thanks @ricardoV94, your suggestions helped to further simplify the code.
On my machine, these changes accelerated the test_subtensor.py suite by almost 10 seconds, even though new test cases were added(!)
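As a rough illustration of the overload pattern discussed in this thread, the sketch below gives a function one overload per specific input type, with a single runtime implementation underneath. FakeConstant and index_literal are hypothetical stand-ins invented for this example; they are not PyTensor's classes or its as_index_literal.

```python
from typing import Optional, overload


class FakeConstant:
    """Hypothetical stand-in for a symbolic constant wrapping a Python value."""

    def __init__(self, data: Optional[int]) -> None:
        self.data = data


# The overloads exist only for the type checker; their `...` bodies never run.
@overload
def index_literal(idx: int) -> int: ...
@overload
def index_literal(idx: slice) -> slice: ...
@overload
def index_literal(idx: FakeConstant) -> Optional[int]: ...


def index_literal(idx):
    """Single runtime implementation that all overloads describe."""
    if isinstance(idx, FakeConstant):
        return idx.data
    if isinstance(idx, slice):
        # Convert each component, leaving missing components as None
        parts = (idx.start, idx.stop, idx.step)
        return slice(*(None if p is None else index_literal(p) for p in parts))
    if isinstance(idx, int):
        return idx
    raise TypeError(f"unsupported index type: {type(idx).__name__}")
```

With specific types in each overload, a type checker such as mypy can infer that index_literal(slice(...)) is a slice without an isinstance check at the call site. The stub bodies are never executed, which is presumably what the coverage remark earlier in the conversation refers to.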
Looking cleaner, but I still think we can simplify, because we are doing more now than the function was doing before
pytensor/tensor/subtensor.py
Outdated
    )

if not isinstance(idx, Variable):
    raise NotScalarConstantError()
- raise NotScalarConstantError()
+ raise TypeError()
pytensor/tensor/subtensor.py
Outdated
@overload
def get_canonical_form_slice(
    theslice: slice | SliceConstant,
It's not just Constants right?
- theslice: slice | SliceConstant,
+ theslice: slice | SliceVariable,
Yes, but there's no SliceVariable type at the moment
Oh strange. So there are Variables with type SliceType but no specific class like TensorVariable or ScalarVariable?
exactly
pytensor/tensor/subtensor.py
Outdated
def get_canonical_form_slice(
    theslice: slice | Variable, length: Variable
) -> tuple[Variable, int]:
    theslice: slice | Variable,
Use the union of the specific types we overloaded above? Variable allows more than SliceVariable | ScalarVariable | TensorVariable but our function does not
The Variable input covers the slice-variable for which we don't have a dedicated type.
We should create one then
pytensor/tensor/subtensor.py
Outdated
value = switch(lt(value, 0), (value + length), value)
# Convert the two symbolic slice types into a native slice
if isinstance(theslice, SliceConstant):
I still think this function is not supposed to handle non-python slices (be it constant or not).
That allows us to simplify our logic and type hints quite a lot?
As long as symbolic slices exist, I think this function should support them. How else should they be handled?
For now, I'd just like to get the typing fixed.
This function does not seem to be intended for SliceTypes. I wrote elsewhere but there are two ways of representing slices in PyTensor: basic index operations don't have explicit slice inputs, they just take as inputs the start/stop/step variables that define a slice.
Advanced indexing operations instead have explicit slice variables of SliceVariable type as inputs. This function was seemingly not meant to work with those types, and unless that's why you went into this refactor, there's no reason to extend its behavior to handle the slice types that are only used by advanced indexing operations.
Does that make sense?
pytensor/tensor/subtensor.py
Outdated
    sslice = theslice
elif isinstance(theslice, TensorVariable) and theslice.ndim == 0:
    sslice = theslice
else:
Just raise a TypeError, no other types should show up
There are almost a hundred places where NotScalarConstantError is caught. Replacing it with a generic TypeError appears very risky.
It's being misused here, why do you want to raise it?
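The trade-off in this exchange can be sketched in plain Python. The classes and functions below are hypothetical stand-ins: NotScalarConstantError here is only a local placeholder for the PyTensor exception of the same name, and Symbol, constant_value and fold_or_keep are invented for this example. The sketch shows that a caller catching the specific exception will not catch a TypeError, so which one gets raised changes what existing callers see.

```python
class NotScalarConstantError(Exception):
    """Local placeholder for the PyTensor exception of the same name."""


class Symbol:
    """Hypothetical symbolic scalar whose value is unknown until runtime."""


def constant_value(x):
    if isinstance(x, int):
        return x
    if isinstance(x, Symbol):
        # A supported kind of input that simply has no constant value
        raise NotScalarConstantError("value is not known at graph-build time")
    # An input kind that should never reach this function at all
    raise TypeError(f"unsupported input type: {type(x).__name__}")


def fold_or_keep(x):
    """Typical caller: fall back gracefully when no constant is available."""
    try:
        return constant_value(x)
    except NotScalarConstantError:
        return x
```

Under this split, fold_or_keep(Symbol()) returns the symbol unchanged, while fold_or_keep("a") propagates the TypeError to the caller. Whether the unsupported-type case in the PR belongs to the first or the second category is the point the two comments disagree on.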
tests/tensor/test_subtensor.py
Outdated
def test_constant_slice(self):
    idx = as_symbolic_slice(slice(3, 7, 2))
    assert isinstance(idx, SliceConstant)
I don't think we need to support non python slices. The code before wasn't, so there is no specified need?
Why are there symbolic slices to begin with?
If we don't need them, we should remove them. If we need them, I think the functions dealing with slices should deal with them too.
Replied here: #823 (comment)
They are used in a different class of Ops
67c8df9
to
acd537b
Compare
Thanks @michaelosthege
Description
I started fixing type hints here and there, and then focused on the pytensor.tensor.subtensor module.
It was tricky, but by adding some overloads (to help mypy) and refactoring one function, I managed to clear errors in that file.
The only remaining error is the Cannot determine type thing for an imported dispatched function.
Checklist
Type of change