michaelosthege
Member

Description

I started fixing type hints here and there, and then focused on the pytensor.tensor.subtensor module.

It was tricky, but by adding some overloads (to help mypy) and refactoring one function, I managed to clear errors in that file.

The only remaining error is the "Cannot determine type" error for an imported dispatched function.
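
For readers unfamiliar with the overload pattern mentioned above, here is a minimal sketch of how `typing.overload` lets a type checker pick a return type per input type. The `Constant` class is a hypothetical stand-in, not PyTensor's class, and the signatures are simplified assumptions rather than the ones in this PR.

```python
from __future__ import annotations

from typing import overload


class Constant:
    """Hypothetical stand-in for a symbolic constant wrapping a Python value."""

    def __init__(self, data: int | None) -> None:
        self.data = data


# The decorated stubs only describe input/output pairs for the type checker.
@overload
def as_index_literal(idx: int) -> int: ...
@overload
def as_index_literal(idx: slice) -> slice: ...
@overload
def as_index_literal(idx: Constant) -> int | None: ...


def as_index_literal(idx: int | slice | Constant) -> int | slice | None:
    # This undecorated definition is the only one that runs.
    if isinstance(idx, Constant):
        return idx.data
    return idx
```

With these stubs, a checker such as mypy can infer `int` for `as_index_literal(3)` instead of the full union, which is the kind of narrowing the overloads are meant to provide.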

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

codecov bot commented Jun 16, 2024

Codecov Report

Attention: Patch coverage is 82.00000% with 9 lines in your changes missing coverage. Please review.

Project coverage is 80.98%. Comparing base (7a00b88) to head (acd537b).
Report is 152 commits behind head on main.

Files with missing lines        Patch %   Lines
pytensor/tensor/subtensor.py    79.54%    1 Missing and 8 partials ⚠️
Additional details and impacted files

@@            Coverage Diff             @@
##             main     #823      +/-   ##
==========================================
- Coverage   80.99%   80.98%   -0.02%     
==========================================
  Files         169      169              
  Lines       46939    46961      +22     
  Branches    11467    11478      +11     
==========================================
+ Hits        38019    38032      +13     
- Misses       6713     6714       +1     
- Partials     2207     2215       +8     
Files with missing lines            Coverage Δ
pytensor/graph/basic.py             88.61% <ø> (ø)
pytensor/tensor/random/op.py        93.78% <100.00%> (ø)
pytensor/tensor/random/utils.py     100.00% <100.00%> (ø)
pytensor/tensor/shape.py            92.68% <100.00%> (ø)
pytensor/tensor/type.py             94.54% <100.00%> (ø)
pytensor/tensor/subtensor.py        89.27% <79.54%> (-0.62%) ⬇️

@michaelosthege
Member Author

the coverage ❌ is due to the overload signatures - can't do anything about it

Member

@ricardoV94 ricardoV94 left a comment

Had a look, have some doubts about a few of the changes, but the direction seems great.

In addition, if you want to remove the lines about sys.maxsize from the giant canonical slice function feel free to do it, as we are not using that for sure (and there's even an inline comment in one of its uses, saying so)

    else:
        sslice = vlit
except NotScalarConstantError:
    raise ValueError(f"Slice {theslice} is not a supported slice type.")
Member

@ricardoV94 ricardoV94 Jun 17, 2024

Couldn't it be a slice with symbolic start/step/stop at this point?

From reading the code it seems you changed the meaning of the function to fail with certain symbolic inputs, such as slice(pt.lscalar("start"), pt.lscalar("stop"), pt.lscalar("step")) or a scalar index pt.lscalar("idx"), which it accommodated before.

Member

@ricardoV94 ricardoV94 Jun 17, 2024

Okay, this function does not expect symbolic slices (or at least it was never tested with them), so the new changes shouldn't either?

If it's not a slice it must be a scalar index, and you shouldn't have to worry about the slice case showing up again here; the old code was certainly not worried about that case. If it should handle symbolic slice (i.e., SliceVariable) types, then we also need tests for those.

Member Author

It wasn't testing them, but as_index_literal already took them apart into a native slice:

idx = slice(*idx.owner.inputs)

Consequently, get_canonical_form_slice handled them already.
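
To make the unpacking step concrete, here is a small sketch of turning a symbolic slice back into a native one via `slice(*idx.owner.inputs)`. The `Apply`, `Variable`, and `symbolic_slice` names are hypothetical stand-ins written for this illustration; they only mimic the `owner.inputs` attribute access quoted above and are not PyTensor's classes.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Apply:
    """Hypothetical stand-in for the node that produced a variable."""

    inputs: list


@dataclass
class Variable:
    """Hypothetical stand-in for a symbolic variable."""

    owner: Apply | None = None


def symbolic_slice(start, stop, step) -> Variable:
    # Stand-in for an Op that builds a slice variable from three inputs.
    return Variable(owner=Apply(inputs=[start, stop, step]))


def to_native_slice(idx: slice | Variable) -> slice:
    if isinstance(idx, slice):
        return idx
    if idx.owner is None:
        raise ValueError("Cannot unpack a slice variable without an owner.")
    # The three owner inputs are start, stop and step, in that order.
    return slice(*idx.owner.inputs)
```

After this conversion, downstream code only ever sees Python slices, which is the situation both sides of this thread are describing.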

Member

That still means this function only needs to handle python slices, and the inner code can be simplified

Comment on lines 250 to 251
if vlit is None:
    return slice(0, length, 1), 1
Member

Are you sure? Isn't None a newaxis, which shouldn't correspond to any slice?

Member

@ricardoV94 ricardoV94 Jun 17, 2024

Also unclear whether this function is expected to handle None. Like SliceVariable, those only show up in the index for AdvancedSubtensor Ops. For basic Subtensor Ops, newaxis are eagerly converted to expand_dims, and never show as inputs of the Op

Member Author

From the docstring I would say the function is not supposed to do None → None things.
There were no tests for that either, and we would create a bigger mess by hinting a possibly None return type.

Actually, I think we should raise ValueError in this situation.
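
As a rough illustration of what is being discussed, here is a pure-Python analogue of a canonical slice form for constant integers. It is a sketch under assumptions, not the PyTensor implementation: `canonical_form` is a hypothetical name, it works on plain ints rather than symbolic variables, and the exact start/stop values PyTensor returns may differ. It returns a slice with a positive step plus a direction flag, wraps negative scalar indices, and raises for None as proposed above.

```python
def canonical_form(idx, length):
    """Return (canonical index, direction) for an int or a Python slice."""
    if idx is None:
        # None means "new axis", which is not a slice at all.
        raise ValueError("None (newaxis) has no canonical slice form.")
    if isinstance(idx, slice):
        # slice.indices resolves None and negative bounds against the length.
        start, stop, step = idx.indices(length)
        if step > 0:
            return slice(start, max(start, stop), step), 1
        selected = range(start, stop, step)
        if len(selected) == 0:
            return slice(0, 0, 1), 1
        # Negative step: select the same items ascending, then reverse them.
        return slice(selected[-1], selected[0] + 1, -step), -1
    # Scalar index: wrap a negative value around the length.
    return (idx + length if idx < 0 else idx), 1
```

Indexing with the returned slice and then applying `[::direction]` gives the same items as the original slice, and `slice(None)` maps to `slice(0, length, 1)` with direction 1, whereas None itself is rejected.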

Member

ricardoV94 commented Jun 17, 2024

Regarding the graphs being complicated, there's a related issue #112

value = theslice

value = switch(lt(value, 0), (value + length), value)
# Try to extract a scalar slice and return it already.
Member

@ricardoV94 ricardoV94 Jun 17, 2024

This branch might be cleaner if it explicitly avoids handling slices, by checking not isinstance(theslice, slice | SliceVariable)? I don't know if newaxis are supposed to make it to this function; I had a question below about it.

And I would call this scalar indexing, not scalar slices?

Then after this branch you can force the SliceVariable to be a normal slice by doing as_index_literal or slice(*sslice.owner.inputs)

Edit: After reading the pre-existing code and tests I am not sure this function is supposed to handle SliceVariables, so ignore that part of the remark

Member Author

It can do that, but I don't think we should "advertise" it in the type hints. It's covered by the Variable hint already.



class TestGetCanonicalFormSlice:
    def test_none_constant(self):
Member

Very unsure about this. None != slice(None) in indexing

Member Author

Yes, I changed this into raising a ValueError, because adding a new axis is something other than canonical slicing.

See comment above too.

def test_symbolic_slice(self):
    idx = make_slice(slice(3, 7, 2))
    assert not isinstance(idx, slice)
    res = get_canonical_form_slice(idx, 10)
Member

Nitpick suggestion for these kinds of tests

Suggested change
- res = get_canonical_form_slice(idx, 10)
+ canonical_slice, direction = get_canonical_form_slice(idx, 10)

Member Author

Too scared of black's line wrapping wrath

    assert res[1] == 1

def test_symbolic_slice(self):
    idx = make_slice(slice(3, 7, 2))
Member

@ricardoV94 ricardoV94 Jun 17, 2024

This is the part I was confused about. I thought the function was originally meant to handle these types, but after more careful reading, I don't think it was?

If it was (or became) meant to handle them, then it should also handle make_slice, with symbolic inputs.

Member Author

No, it can handle variable and constant symbolic slices too 😎



@overload
def as_index_literal(idx: Constant) -> int | np.integer | None: ...
Member

@ricardoV94 ricardoV94 Jun 17, 2024

You can be more specific by using NoneConst, SliceConst, Tensor|ScalarConstant and SliceVariable, Tensor|ScalarVariable in the overloaded signatures

Member

TensorConstants are already a subclass of TensorVariable (so you don't need that distinction at the signature level), but the other types ScalarConstant | SliceConstant may not be. You can make them if you want.

Member

This is the PR where we made some of those subclasses: #628

Member Author

Thanks for pointing out that these types exist. With them, I was able to simplify the function, and among other things avoid the getattr too.

I will not reply to your other comments on as_index_literal, since they no longer apply.

Member Author

@michaelosthege michaelosthege left a comment

Thanks @ricardoV94, your suggestions helped to further simplify the code.

On my machine, these changes accelerated the test_subtensor.py suite by almost 10 seconds, even though new test cases were added(!)



@michaelosthege michaelosthege changed the title from "Fiy typing in subtensor module" to "Fix typing in subtensor module" Jun 19, 2024
Member

@ricardoV94 ricardoV94 left a comment

Looking cleaner, but I still think we can simplify, because we are doing more now than the function was doing before

)

if not isinstance(idx, Variable):
    raise NotScalarConstantError()
Member

Suggested change
- raise NotScalarConstantError()
+ raise TypeError()


@overload
def get_canonical_form_slice(
    theslice: slice | SliceConstant,
Member

It's not just Constants right?

Suggested change
- theslice: slice | SliceConstant,
+ theslice: slice | SliceVariable,

Member Author

Yes, but there's no SliceVariable type at the moment

Member

Oh strange. So there are Variables with type SliceType but no specific class like TensorVariable or ScalarVariable?

Member Author

exactly

  def get_canonical_form_slice(
-     theslice: slice | Variable, length: Variable
- ) -> tuple[Variable, int]:
+     theslice: slice | Variable,
Member

Use the union of the specific types we overloaded above? Variable allows more than SliceVariable | ScalarVariable | TensorVariable but our function does not

Member Author

The Variable input covers the slice-variable for which we don't have a dedicated type.

Member

We should create one then


value = switch(lt(value, 0), (value + length), value)
# Convert the two symbolic slice types into a native slice
if isinstance(theslice, SliceConstant):
Member

I still think this function is not supposed to handle non-python slices (be it constant or not).

That allows us to simplify our logic and type hints quite a lot?

Member Author

As long as symbolic slices exist, I think this function should support them. How else should they be handled?

For now, I'd just like to get the typing fixed.

Member

@ricardoV94 ricardoV94 Jun 23, 2024

This function does not seem to be intended for SliceTypes. I wrote elsewhere but there are two ways of representing slices in PyTensor: basic index operations don't have explicit slice inputs, they just take as inputs the start/stop/step variables that define a slice.

Advanced indexing operations instead have explicit slice variables of SliceVariable type as inputs. This function was seemingly not meant to work with those types, and unless that's why you went into this refactor, there's no reason to extend its behavior to handle the slice types that are only used by advanced indexing operations.

Does that make sense?

    sslice = theslice
elif isinstance(theslice, TensorVariable) and theslice.ndim == 0:
    sslice = theslice
else:
Member

Just raise a TypeError, no other types should show up

Member Author

There are almost a hundred places where NotScalarConstantError is caught. Replacing it with a generic TypeError appears very risky.
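
The risk can be sketched as follows. This is an illustration under assumptions: the `NotScalarConstantError` class below is a stand-in whose base class is assumed for the example, and `scalar_constant` and `constant_or_none` are hypothetical helpers, not PyTensor functions. The point is only that a call site catching the specific error stops handling the failure once a different exception type is raised.

```python
class NotScalarConstantError(Exception):
    """Stand-in for the library exception; the base class is an assumption."""


def scalar_constant(value, error_type=NotScalarConstantError):
    # Raise the configured error when the value is not a plain int.
    if isinstance(value, bool) or not isinstance(value, int):
        raise error_type(f"{value!r} is not a scalar constant")
    return value


def constant_or_none(value, error_type=NotScalarConstantError):
    # Call-site pattern: fall back only on the specific error type.
    try:
        return scalar_constant(value, error_type)
    except NotScalarConstantError:
        return None
```

With the default error type, `constant_or_none("x")` falls back to None; if the helper raises TypeError instead, the same call site lets the exception propagate.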

Member

@ricardoV94 ricardoV94 Jun 23, 2024

It's being misused here, why do you want to raise it?


def test_constant_slice(self):
    idx = as_symbolic_slice(slice(3, 7, 2))
    assert isinstance(idx, SliceConstant)
Member

I don't think we need to support non-Python slices. The code before didn't, so there is no specified need?

Member Author

Why are there symbolic slices to begin with?

If we don't need them, we should remove them. If we need them, I think the functions dealing with slices should deal with them too.

Member

Replied here: #823 (comment)

They are used in a different class of Ops

@ricardoV94 ricardoV94 merged commit f9dfe70 into pymc-devs:main Jun 25, 2024
@ricardoV94
Member

Thanks @michaelosthege

@michaelosthege michaelosthege deleted the subtensor-typing branch June 25, 2024 06:35