
Conversation

@bminixhofer
Contributor

What does this PR do?

Use jax.random.permutation instead of np.random.permutation in the data_loader function of run_clm_flax.py so that it uses the global seed. Currently, the batch order would likely vary across runs regardless of the global seed.

Also changes np.arange to jnp.arange and np.array to jnp.array, and removes the numpy import, although those changes are not strictly necessary.
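
Roughly, the shuffled loader would then look something like this (an illustrative sketch rather than the exact diff; the shard helper and the list conversion when indexing the dataset are illustrative details):

import jax
import jax.numpy as jnp
from flax.training.common_utils import shard


def data_loader(rng, dataset, batch_size, shuffle=False):
    """Yield device-sharded batches from a 🤗 Dataset.

    With `shuffle=True` the order depends only on `rng`, so runs started
    from the same global seed see the same batch order.
    """
    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        # The permutation is now a function of the PRNG key derived from the global seed.
        batch_idx = jax.random.permutation(rng, jnp.arange(len(dataset)))
    else:
        batch_idx = jnp.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # drop the last incomplete batch
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        batch = dataset[[int(i) for i in idx]]               # dict of column name -> list
        batch = {k: jnp.array(v) for k, v in batch.items()}
        yield shard(batch)  # split along the local device axis (batch_size must be divisible by it)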

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@patil-suraj

@patil-suraj
Contributor

Thanks a lot for the PR @bminixhofer, good catch!

However, we actually did use jax.random.permutation and jnp.arange in the data loader initially, but switched to NumPy after observing that it caused issues with JAX's asynchronous dispatch. Since JAX puts everything on the device by default, it could cause problems (especially on TPU) if the dataloader/collator is used with multiple threads for background fetching, and it also leads to major slowdowns. So the Flax examples now avoid JAX functions in pre-processing, loading, collating, etc. This way the TPU can be busy all the time doing the actual computation and is not blocked by processing and loading.

But I see the problem, so maybe we could use the seed with numpy to make data shuffling reproducible.
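
Something like the following, for instance (a rough sketch; the function name and signature are illustrative, not the actual script code): keep the permutation on the host with NumPy, but drive it from the training seed.

import numpy as np


def data_loader(dataset, batch_size, seed=None, shuffle=False):
    """Host-side (NumPy) batching: it stays off the device, so it does not
    block JAX's asynchronous dispatch, yet it is reproducible when `seed`
    is derived from the training seed."""
    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        rng = np.random.default_rng(seed)      # dedicated, seeded generator
        batch_idx = rng.permutation(len(dataset))
    else:
        batch_idx = np.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # drop the last incomplete batch
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        yield dataset[idx.tolist()]            # plain Python ints for 🤗 Dataset indexing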

cc @patrickvonplaten.

@bminixhofer
Contributor Author

Interesting, what about run_mlm_flax.py? I was having a look prior to this PR, and it seems jax.random.permutation is also used for data shuffling, or am I missing something?

num_train_samples = len(tokenized_datasets["train"])
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

@patil-suraj
Contributor

That should also be changed. All flax examples are still not completely consistent with each other, which needs to be fixed.

@patrickvonplaten
Contributor

@patil-suraj - I think here we could actually switch to jax.random.permutation, as it is called only once per epoch. I'm pretty sure it wouldn't slow down the training.

@bminixhofer
Contributor Author

@patil-suraj @patrickvonplaten Is this something you want to change or should I close this PR?

@patrickvonplaten
Contributor

Hey @bminixhofer - sorry for being so slow here. @patil-suraj, I'm happy to merge the PR. I think it's good to have 100% reproducibility with JAX's random seed, and I don't think this slows down the script as it's called just once per epoch. If you're ok with the changes, feel free to merge.

@huggingface huggingface deleted a comment from github-actions bot Oct 22, 2021
@patrickvonplaten
Contributor

@patil-suraj - can you take a look here and leave your opinion so that we can resolve the PR? :-)

@huggingface huggingface deleted a comment from github-actions bot Nov 16, 2021
@patrickvonplaten
Contributor

ping @patil-suraj again

@huggingface huggingface deleted a comment from github-actions bot Dec 13, 2021
@patrickvonplaten
Contributor

ping @patil-suraj again

Contributor

@patil-suraj patil-suraj left a comment

Sorry about being super slow here. Agree with @patrickvonplaten and think we can change this to use jax.random.
Also, note that even with NumPy the batch order will be reproducible across runs, since we set the seed for numpy as well.
cf
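
For example, since set_seed also seeds NumPy's global RNG, a quick sanity check along these lines passes (illustrative snippet, not code from the script):

import numpy as np
from transformers import set_seed

set_seed(42)   # seeds Python's `random`, NumPy, and torch/tf if installed
first = np.random.permutation(10)

set_seed(42)
second = np.random.permutation(10)

assert (first == second).all()   # same seed -> same shuffle order across runs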

@bminixhofer Thank you for fixing this!

@patil-suraj patil-suraj merged commit 2a606f9 into huggingface:master Dec 14, 2021
@bminixhofer
Contributor Author

Thanks for merging this! I vaguely remember having problems with batch order with the code as it was previously, but I am not completely sure (it's been some time 😅 ).

@bminixhofer bminixhofer deleted the data_loader_jax branch December 14, 2021 10:06