Conversation

@Sea-Snell
Copy link

Motivation: Sharded PyTorch checkpoints cannot currently be loaded into Flax models; loading them is desirable in some cases (e.g. "google/ul2").

Changes: I added a few lines to modeling_flax_utils.py to support this behavior. The added code exactly matches how sharded checkpoints are loaded in modeling_utils.py for PyTorch models.
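
For context, the behavior this enables looks roughly like the following sketch (the Flax class for "google/ul2" is assumed to be the T5 one, and a local clone of the repo is assumed, since the new branch only checks local paths):

    from transformers import FlaxT5ForConditionalGeneration

    # "google/ul2" ships only sharded PyTorch weights; with this change,
    # from_pt=True loads the shards and converts them to Flax parameters.
    model = FlaxT5ForConditionalGeneration.from_pretrained("./ul2", from_pt=True)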

@patrickvonplaten, @patil-suraj

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Collaborator

@sgugger sgugger left a comment

LGTM, thanks for your PR! You'll need to import this constant from the utils submodule however :-)
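
Presumably a one-line fix along these lines, in modeling_flax_utils.py (the exact relative import path is an assumption):

    from .utils import WEIGHTS_INDEX_NAME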

@Sea-Snell
Author

Oops! Thanks, just added that import.

@sgugger
Collaborator

sgugger commented Jul 21, 2022

Now you'll need to run make style to fix the formatting issues :-)

@Sea-Snell
Author

done!

Collaborator

@sgugger sgugger left a comment

@ArthurZucker could you also have a quick look?

Collaborator

@ArthurZucker ArthurZucker left a comment

Thanks for contributing! I would be more in favour of finalising #18026, or you can merge my branch.
Overall, we should always test both locally and on the Hub 😄

Comment on lines +643 to +646
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
is_sharded = True

Collaborator

LGTM, just wondering if you could add a small test?
You can use hf-internal-testing/tiny-random-bert-sharded.
Also, I opened #18026, which is really similar and adds:

    # Needs FlaxBertModel from transformers, is_pt_flax_cross_test from
    # transformers.testing_utils, flatten_dict from flax.traverse_util, and numpy as np.
    @is_pt_flax_cross_test
    def test_from_sharded_pt(self):
        model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
        ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-flax-only")
        # Check that every parameter converted from the sharded PT checkpoint
        # matches the Flax-only reference model.
        for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
            assert np.allclose(np.array(p1), np.array(p2))

Was not really aware that the conversion would be so straightforward; let me have a look.

if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):

Collaborator

BTW, this will only work if the WEIGHTS_INDEX_NAME file is present locally; it does not cover checkpoints that only live on the Hub.
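
For illustration, a rough sketch of what Hub support would additionally need (the hf_hub_download calls are an assumption about how the index lookup could be done, not code from this PR):

    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import EntryNotFoundError

    # If pretrained_model_name_or_path is a repo id rather than a local
    # directory, try to fetch the shard index from the Hub first, and fall
    # back to the single-file checkpoint if the repo is not sharded.
    try:
        archive_file = hf_hub_download(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
        is_sharded = True
    except EntryNotFoundError:
        archive_file = hf_hub_download(pretrained_model_name_or_path, WEIGHTS_NAME)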

@Sea-Snell
Author

Yeah, let's just finalize yours. What's left to do?

Collaborator

Maybe just fixing the tests, and making sure that the tests are actually good. Should be quite straightforward ☺️🙌

Collaborator

@Sea-Snell we just need to fix test_from_sharded_pt, which is failing because the models used for comparison are not the same! Simply using the same model should do the trick: upload a new model with the same config, but shard it by calling save_pretrained with max_shard_size set to 150KB.
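
For reference, a minimal sketch of how such a sharded twin could be produced (from_flax is used here as one way to keep the weights identical to the Flax-only reference; the local save path is illustrative):

    from transformers import BertModel

    # Load the reference weights into PyTorch, converting from the
    # Flax-only checkpoint so both test models share identical weights...
    model = BertModel.from_pretrained("ArthurZ/tiny-random-bert-flax-only", from_flax=True)

    # ...then re-save with a tiny max_shard_size so the state dict is
    # split into several shard files, and push the result to the Hub.
    model.save_pretrained("tiny-random-bert-sharded", max_shard_size="150KB")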

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Aug 27, 2022