
Conversation

aseyboldt (Member)

Compile time for DimShuffle ops is pretty long:

import pytensor
import pytensor.tensor as pt
import numpy as np
import numba

z = pt.dscalar("z")
out = z[None, None]  # a DimShuffle: add two broadcastable dimensions to a scalar
z_val = np.array(0.1)

%%time
func = pytensor.function([z], out, mode="NUMBA")
func(z_val)  # the first call triggers numba compilation


# Before
CPU times: user 1.6 s, sys: 23.5 ms, total: 1.62 s
Wall time: 1.62 s

# After
CPU times: user 661 ms, sys: 36.1 ms, total: 697 ms
Wall time: 697 ms

There is still a way to go, and unfortunately it doesn't seem to have as big an impact on the compile times of larger models as I hoped, but it is a start...

Going forward, I think we should try to find more ops like this, where the compile time of a single op is large, and also try a few other things:

  • Do we really need to ask for O3 optimization all the time? I guess O2 might be enough for most cases.
  • We recreate the numba functions all the time, which means that LLVM will potentially see the same functions multiple times and will also have to optimize them multiple times. Could we move a lot of the njit functions out of the dispatch functions (see the sketch after this list)? I think the current approach also prevents numba caching from working properly.
  • There are still a lot of inline="always" functions around. Maybe we want to get rid of at least some of those?
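To make the second point concrete, here is a minimal sketch of the hoisting idea, assuming a dispatch layout similar to PyTensor's numba backend; the function names are illustrative, not the actual API:

import numba

# Current pattern: a fresh njit helper is created inside every dispatch call,
# so LLVM has to analyze and optimize an identical function over and over,
# and numba's on-disk cache can never be reused between runs.
def numba_funcify_some_op(op, node, **kwargs):
    @numba.njit
    def helper(x):
        return x + 1

    @numba.njit
    def some_op(x):
        return helper(x) * 2

    return some_op

# Hoisted pattern: the helper lives at module level, is compiled once per
# process, and with cache=True can be cached on disk across processes.
@numba.njit(cache=True)
def _helper(x):
    return x + 1

def numba_funcify_some_op_hoisted(op, node, **kwargs):
    @numba.njit
    def some_op(x):
        return _helper(x) * 2

    return some_op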

aseyboldt force-pushed the compile-time-dimshuffle branch from b753226 to deb6535 on December 8, 2022 at 23:29
codecov-commenter commented on Dec 9, 2022

Codecov Report

Merging #95 (5ccb1f2) into main (491f93e) will increase coverage by 0.18%.
The diff coverage is 86.11%.

Additional details and impacted files

@@            Coverage Diff             @@
##             main      #95      +/-   ##
==========================================
+ Coverage   74.22%   74.41%   +0.18%     
==========================================
  Files         174      179       +5     
  Lines       48734    49249     +515     
  Branches    10367    10422      +55     
==========================================
+ Hits        36175    36649     +474     
- Misses      10272    10295      +23     
- Partials     2287     2305      +18     
Impacted Files Coverage Δ
pytensor/misc/ordered_set.py 79.80% <ø> (ø)
pytensor/tensor/nnet/corr.py 16.81% <0.00%> (ø)
pytensor/link/numba/dispatch/extra_ops.py 92.24% <70.21%> (-5.77%) ⬇️
pytensor/link/numba/dispatch/scalar.py 94.44% <75.00%> (+7.02%) ⬆️
pytensor/link/numba/dispatch/cython_support.py 86.95% <86.95%> (ø)
pytensor/link/numba/dispatch/basic.py 90.06% <95.83%> (-2.62%) ⬇️
pytensor/link/numba/dispatch/elemwise.py 97.04% <97.87%> (-0.09%) ⬇️
pytensor/graph/basic.py 88.10% <100.00%> (+0.43%) ⬆️
pytensor/link/numba/dispatch/nlinalg.py 100.00% <100.00%> (ø)
pytensor/sparse/sandbox/sp.py 73.48% <100.00%> (ø)
... and 25 more

Diff context under review (fragment):

    else:
        new_shape = numba_basic.tuple_setitem(new_shape, i, shuffle_shape[j])
        return j + 1, new_shape

def find_shape(array_shape):
ricardoV94 (Member):

Should we shortcut when the output static shape is known?

pt.random.normal(size=(2, 3)).dimshuffle(1, 0).type.shape  # (3, 2)

aseyboldt (Member, Author):

I guess that might save a tiny bit of runtime and/or compile time, but if the static shape were incorrect for some reason, we would end up writing to arrays out of bounds; there is no error checking after this point.
I think it is safer to always infer the output shape from the inputs...

ricardoV94 (Member):

In that case the graph would be inconsistent and should fail anyway.

aseyboldt (Member, Author):

Yes, it should fail. But if we just assume the shapes are correct, we might silently corrupt memory and not fail in an obvious way. :-)

aseyboldt (Member, Author):

Ah, sorry, I somehow thought this was the shape code for the LLVM elemwise...
Here the reshape should just fail if we provide something incorrect. So yes, I think we can use the extra info we have here; I'll update the PR.
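As a rough illustration of using that extra info (a sketch assuming the output type's static shape is stored as a tuple with None for unknown dims, as in PyTensor; not the actual PR code): if the static shape is fully known, the dispatch could embed it as a constant, and the reshape raises on any inconsistency instead of corrupting memory.

import numpy as np
import numba

def make_static_reshaper(static_shape):
    # static_shape as stored on the output type, e.g. (3, 2); None marks
    # dimensions that are only known at runtime.
    if all(s is not None for s in static_shape):
        shape = tuple(static_shape)

        @numba.njit
        def reshape_static(x):
            # np.reshape raises if x.size does not match the constant shape,
            # so an inconsistent graph fails loudly rather than silently.
            return np.reshape(x, shape)

        return reshape_static
    # Fall back to computing the output shape from the runtime input.
    return None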

ricardoV94 (Member):

Either way it's probably a minor performance difference. I was mostly thinking out loud about the kinds of places where we can benefit from static shape info.

ricardoV94 (Member) left a review:

LGTM

What exactly provides the compilation speedup?

aseyboldt (Member, Author):

I'm still trying to understand numba compile times better, but it seems that in this case the transpose had quite an impact. Most of the time the transpose in DimShuffle is just a no-op (and we know it is), so we can remove it in those cases.
I also changed the code that computes the final shape so that it involves fewer function calls that need to be analyzed and typed.
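For context, a minimal sketch of the no-op check, assuming DimShuffle's new_order convention where "x" marks an inserted broadcastable dimension (the helper name is made up for illustration):

def transpose_is_noop(new_order):
    # Drop the "x" markers that insert broadcastable dimensions; if the
    # surviving input axes are already in increasing order, the transpose
    # would not move any data and can be skipped in the generated function.
    axes = [a for a in new_order if a != "x"]
    return axes == sorted(axes)

# z[None, None] on a scalar is a DimShuffle with new_order == ("x", "x"):
assert transpose_is_noop(("x", "x"))   # reshape only, no transpose
assert transpose_is_noop((0, 1, "x"))  # reshape only, no transpose
assert not transpose_is_noop((1, 0))   # a real transpose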

aseyboldt merged commit d9fe197 into pymc-devs:main on Dec 11, 2022