Skip to content

Conversation

ricardoV94
Copy link
Member

Description

The check was failing incorrectly for cases that are supported such as constant Boolean arrays. Besides that, user may dispatch without necessarily jitting the graph. There is no reason to fail eagerly.

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
Copy link
Member Author

@ricardoV94 ricardoV94 Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise I suspect we were not really testing jax subtensor dispatch as the index operation would be just constant_folded. This was not always the case, as we used to not run any rewrites in compare_jax_and_py. Now we do

compare_jax_and_py(out_fg, [x_np])

# Boolean indexing should work if indexes are constant
out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5))]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the case that led me to open this PR. JAX is happy to do it but our check wouldn't let it compile

@ricardoV94 ricardoV94 force-pushed the jax_eager_subtensor branch from c27898a to 6fc8f7a Compare June 24, 2024 11:57
The check was failing incorrectly for cases that are supported such as constant Boolean arrays.
Besides that, user may dispatch without necessarily jitting the graph. There is no reason to fail eagerly.
@ricardoV94 ricardoV94 force-pushed the jax_eager_subtensor branch from 6fc8f7a to 2b99224 Compare June 24, 2024 12:18
Copy link

codecov bot commented Jun 24, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 80.88%. Comparing base (d3bd1f1) to head (2b99224).
Report is 158 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #849      +/-   ##
==========================================
- Coverage   80.89%   80.88%   -0.01%     
==========================================
  Files         169      169              
  Lines       46979    46966      -13     
  Branches    11478    11472       -6     
==========================================
- Hits        38002    37989      -13     
  Misses       6764     6764              
  Partials     2213     2213              
Files with missing lines Coverage Δ
pytensor/link/jax/dispatch/subtensor.py 86.95% <100.00%> (-2.88%) ⬇️

@ricardoV94 ricardoV94 changed the title Remove false positive checks for supported Subtensors operations in JAX Remove conservative checks for supported Subtensors operations in JAX Jun 25, 2024
@lucianopaz lucianopaz merged commit 684a929 into pymc-devs:main Jun 28, 2024
@ricardoV94 ricardoV94 mentioned this pull request Jun 28, 2024
11 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants