Use barriers to reduce duplicate work/resources #9
```diff
@@ -245,17 +245,20 @@ def save_pretrained(self, save_directory, xla_device=False):
         # Only save the model itself if we are using distributed training
         model_to_save = self.module if hasattr(self, 'module') else self
 
-        # Save configuration file
-        model_to_save.config.save_pretrained(save_directory)
-
         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
 
+        # Save configuration file
         if xla_device:
-            # Saving for each process is safe since output_dir is proc# namespaced.
             import torch_xla.core.xla_model as xm
-            xm.save(model_to_save.state_dict(), output_model_file, master_only=False)
+            if xm.is_master_ordinal():
+                model_to_save.config.save_pretrained(save_directory)
+            # xm.save takes care of saving only from master
+            xm.save(model_to_save.state_dict(), output_model_file)
         else:
+            model_to_save.config.save_pretrained(save_directory)
             torch.save(model_to_save.state_dict(), output_model_file)
 
         logger.info("Model weights saved in {}".format(output_model_file))
 
     @classmethod
```

**Review comment** (on the `xm.save(...)` line): Idea for the future: you could use the checkpoint tagger later, if they mark which checkpoint is the best, etc.

**Reply:** As far as I know, they don't have such a tagger, but let me see if they do. Good idea; we could even upstream it if they don't. Thanks.
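For context on the new code path, here is a minimal, self-contained sketch of master-only checkpointing with `torch_xla`. It illustrates the pattern only and is not the fork's code; `save_checkpoint`, `model`, `save_directory`, and the file names are assumed placeholders:

```python
# Sketch of the master-only saving pattern (assumed names, not the PR's code).
# Intended to run inside each torch_xla worker process, e.g. under xmp.spawn().
import os
import torch_xla.core.xla_model as xm

def save_checkpoint(model, save_directory):
    # Once-per-job files (here a stand-in for config.save_pretrained) are
    # written only by the master ordinal, so replicas never race on them.
    if xm.is_master_ordinal():
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            f.write("{}")  # placeholder for the real config contents

    # Every replica calls xm.save, but with the default master_only=True it
    # moves the tensors to CPU and writes the file from the master alone.
    xm.save(model.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
```

This behavior is what lets the diff drop `master_only=False` together with the per-process (`proc#`-namespaced) output directories: one checkpoint is written per job instead of one per replica.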
```diff
@@ -371,8 +374,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
 
         # redirect to the cache, if necessary
         try:
-            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
-                                                proxies=proxies, xla_device=xla_device)
+            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
         except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                 msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
```
**Review comment:** I'm assuming no tensor work is done in the methods between the `download_only_once` rendezvous's.

**Reply:** Yep, correct.
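The `download_only_once` helper mentioned in the thread is not part of the hunk above; the following is a sketch of the barrier pattern being discussed, where the function name comes from the thread but the body is assumed:

```python
# Sketch of a download-once barrier with torch_xla (assumed implementation).
import torch_xla.core.xla_model as xm

def download_only_once(fetch):
    """Run `fetch` on the master first; everyone else waits, then hits the cache."""
    if not xm.is_master_ordinal():
        # Non-master processes block here until the master joins the barrier,
        # i.e. until the download below has populated the shared cache.
        xm.rendezvous("download_only_once")
    result = fetch()  # master: real download; others: warm cache hit
    if xm.is_master_ordinal():
        # The master reaches the barrier last, releasing the waiting processes.
        xm.rendezvous("download_only_once")
    return result
```

The question above is exactly the safety condition for this pattern: between the two `rendezvous` calls the processes are intentionally out of lockstep, so the code in that window must not perform any tensor or collective work, which the reply confirms.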