Allow loading pretrained sharded PyTorch checkpoints into Flax models #18170
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
sgugger left a comment
LGTM, thanks for your PR! You'll need to import this constant from the utils submodule however :-)
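For reference, a minimal sketch of the import in question, assuming WEIGHTS_INDEX_NAME lives alongside the other weight-file constants in the utils submodule:

```python
# Sketch of the import sgugger is asking for inside modeling_flax_utils.py;
# the exact import block it joins may differ.
from .utils import WEIGHTS_INDEX_NAME
```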
Oops! Thanks, just added that import.
Now you'll need to run

Done!
sgugger left a comment
@ArthurZucker could you also have a quick look?
ArthurZucker left a comment
Thanks for contributing! I would be more in favour of finalising #18026, or you can merge my branch.
Overall we should always test both locally and on the hub 😄
```python
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
    # Load from a sharded PyTorch checkpoint
    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
    is_sharded = True
```
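For context, WEIGHTS_INDEX_NAME points at the JSON index file of a sharded checkpoint. A minimal sketch of how such an index can be resolved into a single state dict, assuming the standard `weight_map` layout; the helper name is illustrative, not the actual library internals:

```python
import json
import os

import torch


def load_sharded_pytorch_state_dict(checkpoint_dir, index_name="pytorch_model.bin.index.json"):
    """Merge all shards referenced by a checkpoint index into one state dict."""
    # The index maps each parameter name to the shard file that stores it.
    with open(os.path.join(checkpoint_dir, index_name)) as f:
        weight_map = json.load(f)["weight_map"]
    state_dict = {}
    for shard_file in sorted(set(weight_map.values())):
        state_dict.update(torch.load(os.path.join(checkpoint_dir, shard_file), map_location="cpu"))
    return state_dict
```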
LGTM, just wondering if you could add a small test?
You can use hf-internal-testing/tiny-random-bert-sharded.
Also, I opened #18026, which is really similar and adds:
```python
@is_pt_flax_cross_test
def test_from_sharded_pt(self):
    model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
    ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-flax-only")
    for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
        assert np.allclose(np.array(p1), np.array(p2))
```

Was not really aware that the conversion would be that straightforward, let me have a look.
```python
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
    # Load from a PyTorch checkpoint
    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
```
BTW, this will only work if the WEIGHTS_INDEX_NAME file is present locally; it does not cover loading from the Hub.
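A minimal sketch of what Hub support could look like, using huggingface_hub.hf_hub_download; the helper name is illustrative:

```python
import json

from huggingface_hub import hf_hub_download


def download_sharded_checkpoint(repo_id, index_name="pytorch_model.bin.index.json"):
    # Resolve the index first, then fetch every shard it references;
    # hf_hub_download returns local cache paths.
    index_path = hf_hub_download(repo_id, index_name)
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    return [hf_hub_download(repo_id, shard) for shard in sorted(set(weight_map.values()))]
```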
Yeah, let's just finalize yours. What's left to do?
Maybe just fixing the tests, and making sure that the tests are actually good. Should be quite straightforward.
@Sea-Snell we just need to fix test_from_sharded_pt, which is failing because the models used for comparison are not the same! Simply using the same model should do the trick: upload a new model with the same config, but shard it with save_pretrained, setting max_shard_size to 150KB.
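A sketch of that fix, assuming a tiny random BERT config; the config values and repo name are illustrative:

```python
from transformers import BertConfig, BertModel

# Build a tiny model, then shard it on save: with a small max_shard_size the
# checkpoint is split across several .bin files plus a
# pytorch_model.bin.index.json index.
config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = BertModel(config)
model.save_pretrained("tiny-random-bert-sharded", max_shard_size="150KB")
# model.push_to_hub("tiny-random-bert-sharded")  # then upload it for the test
```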
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Motivation: Sharded PyTorch checkpoints cannot currently be loaded into Flax models; this may be desirable in some cases (e.g. "google/ul2").
Changes: I added a few lines to modeling_flax_utils.py to support this behavior. The added code exactly matches how sharded checkpoints are loaded in modeling_utils.py for PyTorch models. @patrickvonplaten, @patil-suraj
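After this change, loading a locally sharded PyTorch checkpoint into a Flax model should be as simple as the following (the directory name is illustrative):

```python
from transformers import FlaxBertModel

# from_pt=True triggers the PyTorch -> Flax conversion; with this PR it also
# covers directories containing a sharded pytorch_model.bin.index.json.
model = FlaxBertModel.from_pretrained("./tiny-random-bert-sharded", from_pt=True)
```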