Make data shuffling in `run_clm_flax.py` respect global seed (#13410)
Conversation
Thanks a lot for the PR @bminixhofer, good catch! However, initially we actually used `jax.random.permutation` here before moving to NumPy. But I see the problem, so maybe we could use the seed with NumPy to make data shuffling reproducible. cc @patrickvonplaten
Interesting, what about `run_mlm_flax.py`? (transformers/examples/flax/language-modeling/run_mlm_flax.py, lines 624 to 626 at 76c4d8b)
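For reference, the referenced span presumably shuffles indices with NumPy's global RNG along these lines (a hedged reconstruction with dummy sizes so it runs standalone, not the verbatim source):

```python
import numpy as np

# Reconstruction (assumption): the MLM example shuffles training indices
# with NumPy's global RNG rather than a JAX PRNG key.
num_train_samples, train_batch_size = 1000, 32  # dummy sizes, not from the script
train_samples_idx = np.random.permutation(np.arange(num_train_samples))
# Split the shuffled indices into per-step batches.
train_batch_idx = np.array_split(train_samples_idx, num_train_samples // train_batch_size)
```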
That should also be changed. All Flax examples are still not completely consistent with each other, which needs to be fixed.
@patil-suraj - I think here we could actually switch to jax.random.permutation as this is called only once per epoch. I'm pretty sure that it wouldn't slow down the training
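A minimal sketch of that pattern (the key splitting and names like `input_rng` are illustrative, not the exact PR diff):

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(42)  # illustrative seed

for epoch in range(3):
    # Split off a fresh key per epoch; the resulting permutation is then
    # fully determined by the original seed.
    rng, input_rng = jax.random.split(rng)
    perm = jax.random.permutation(input_rng, jnp.arange(10))
    print(epoch, perm)
```

Since this runs once per epoch rather than per batch, the cost of going through JAX's PRNG is negligible.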
(force-pushed from 51d0e82 to 3279d45)
@patil-suraj @patrickvonplaten Is this something you want to change, or should I close this PR?
Hey @bminixhofer - sorry for being so slow here. @patil-suraj I'm happy to merge the PR. Think it's good to have 100% reproducibility with JAX's random seed. I don't think this slows down the script as it's called just once per epoch. If you're ok with the changes, feel free to merge.
@patil-suraj - can you take a look here and leave your opinion so that we can resolve the PR? :-)
ping @patil-suraj again

ping @patil-suraj again
Sorry about being super slow here. Agree with @patrickvonplaten and think we can change this to use `jax.random`. Also, note that even with NumPy the batch order will be preserved across runs, since we set the seed for NumPy as well; cf. `set_seed(training_args.seed)` in the script.
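For context, `transformers.set_seed` seeds Python's `random` module, NumPy, and (when available) the deep-learning framework in one call, which is why the NumPy-based shuffle already repeats across runs with the same seed:

```python
import numpy as np
from transformers import set_seed

set_seed(0)
a = np.random.permutation(5)
set_seed(0)
b = np.random.permutation(5)
# NumPy's global RNG was reseeded, so the shuffle order repeats exactly.
assert (a == b).all()
```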
@bminixhofer Thank you for fixing this!
Thanks for merging this! I vaguely remember having problems with batch order with the code as it was previously, but I am not completely sure (it's been some time 😅).
What does this PR do?
Use `jax.random.permutation` instead of `np.random.permutation` in the `data_loader` function of `run_clm_flax.py` to make it use the global seed. Currently batch order would probably vary across runs, regardless of the global seed.

Also changes `np.arange` to `jnp.arange` and `np.array` to `jnp.array` and removes the NumPy import, although that would not be strictly necessary.
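A hedged sketch of what the updated `data_loader` might look like after this change (simplified relative to the real script, which also shards batches across devices):

```python
import jax
import jax.numpy as jnp

def data_loader(rng, dataset, batch_size, shuffle=False):
    """Yield batches of dataset items in a seed-determined order (sketch)."""
    steps_per_epoch = len(dataset) // batch_size
    if shuffle:
        # The permutation depends only on the JAX PRNG key derived from
        # the global seed, so batch order is reproducible across runs.
        batch_idx = jax.random.permutation(rng, jnp.arange(len(dataset)))
    else:
        batch_idx = jnp.arange(len(dataset))
    # Drop the incomplete last batch and group indices per training step.
    batch_idx = batch_idx[: steps_per_epoch * batch_size]
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    for idx in batch_idx:
        yield {k: jnp.array(v) for k, v in dataset[idx].items()}
```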
Who can review?
@patil-suraj