Skip to content

Conversation

@ydshieh
Copy link
Collaborator

@ydshieh ydshieh commented Dec 14, 2021

What does this PR do?

run_summarization_flax.py has

decoder_input_ids = shift_tokens_right_fn(
jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
)

Using jnp.array here will cause preprocess_function to hang forever when it is used by datasets.Dataset.map() with num_proc > 1, when this script is running on a TPU VM.

I think it is related to #12719 and #12720

Who can review?

@patil-suraj @patrickvonplaten

@patrickvonplaten
Copy link
Contributor

@patil-suraj - could you take this one? :-)

Copy link
Contributor

@patil-suraj patil-suraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch! Thanks a lot for fixing this @ydshieh !

@patil-suraj patil-suraj merged commit a94105f into huggingface:master Dec 15, 2021
@ydshieh ydshieh deleted the fix_flax_summarization_example branch May 5, 2022 10:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants