-
Notifications
You must be signed in to change notification settings - Fork 31.1k
Safetensors serialization by default #27064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
aa973a4
208ca04
0583271
e4ebba7
ddacff6
b6c1d2c
2fcda30
522ced8
b5699ba
42534ea
4c09b62
3ae3c62
1bd29e0
a79600f
c979b85
a9636e2
fa60654
db20c18
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,6 +39,8 @@ | |
| from .utils import ( | ||
| FLAX_WEIGHTS_INDEX_NAME, | ||
| FLAX_WEIGHTS_NAME, | ||
| SAFE_WEIGHTS_INDEX_NAME, | ||
| SAFE_WEIGHTS_NAME, | ||
| WEIGHTS_INDEX_NAME, | ||
| WEIGHTS_NAME, | ||
| PushToHubMixin, | ||
|
|
@@ -54,8 +56,14 @@ | |
| replace_return_docstrings, | ||
| ) | ||
| from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files | ||
| from .utils.import_utils import is_safetensors_available | ||
|
|
||
|
|
||
| if is_safetensors_available(): | ||
| from safetensors import safe_open | ||
| from safetensors.flax import load_file as safe_load_file | ||
| from safetensors.flax import save_file as safe_save_file | ||
|
|
||
| logger = logging.get_logger(__name__) | ||
|
|
||
|
|
||
|
|
@@ -422,6 +430,31 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): | |
| ```""" | ||
| return self._cast_floating_to(params, jnp.float16, mask) | ||
|
|
||
| @classmethod | ||
| def load_flax_weights(cls, resolved_archive_file): | ||
| try: | ||
| if resolved_archive_file.endswith(".safetensors"): | ||
| state = safe_load_file(resolved_archive_file) | ||
| state = unflatten_dict(state, sep=".") | ||
| else: | ||
| with open(resolved_archive_file, "rb") as state_f: | ||
| state = from_bytes(cls, state_f.read()) | ||
| except (UnpicklingError, msgpack.exceptions.ExtraData) as e: | ||
| try: | ||
| with open(resolved_archive_file) as f: | ||
| if f.read().startswith("version"): | ||
| raise OSError( | ||
| "You seem to have cloned a repository without having git-lfs installed. Please" | ||
| " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" | ||
| " folder you cloned." | ||
| ) | ||
| else: | ||
| raise ValueError from e | ||
| except (UnicodeDecodeError, ValueError): | ||
| raise EnvironmentError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ") | ||
|
|
||
| return state | ||
|
|
||
| @classmethod | ||
| def load_flax_sharded_weights(cls, shard_files): | ||
| """ | ||
|
|
@@ -687,7 +720,12 @@ def from_pretrained( | |
| pretrained_model_name_or_path = str(pretrained_model_name_or_path) | ||
| is_local = os.path.isdir(pretrained_model_name_or_path) | ||
| if os.path.isdir(pretrained_model_name_or_path): | ||
| if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): | ||
| if is_safetensors_available() and os.path.isfile( | ||
| os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) | ||
| ): | ||
| # Load from a safetensors checkpoint | ||
| archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) | ||
| elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): | ||
| # Load from a PyTorch checkpoint | ||
| archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) | ||
| elif from_pt and os.path.isfile( | ||
|
|
@@ -704,6 +742,13 @@ def from_pretrained( | |
| archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME) | ||
| is_sharded = True | ||
| # At this stage we don't have a weight file so we will raise an error. | ||
| elif is_safetensors_available() and os.path.isfile( | ||
| os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) | ||
| ): | ||
| # Load from a sharded safetensors checkpoint | ||
| archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) | ||
| is_sharded = True | ||
| raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!") | ||
| elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): | ||
| raise EnvironmentError( | ||
| f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " | ||
|
|
@@ -722,7 +767,13 @@ def from_pretrained( | |
| filename = pretrained_model_name_or_path | ||
| resolved_archive_file = download_url(pretrained_model_name_or_path) | ||
| else: | ||
| filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME | ||
| if from_pt: | ||
| filename = WEIGHTS_NAME | ||
| elif is_safetensors_available(): | ||
| filename = SAFE_WEIGHTS_NAME | ||
| else: | ||
| filename = FLAX_WEIGHTS_NAME | ||
|
|
||
| try: | ||
| # Load from URL or cache if already cached | ||
| cached_file_kwargs = { | ||
|
|
@@ -740,8 +791,15 @@ def from_pretrained( | |
| } | ||
| resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) | ||
|
|
||
| # Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None | ||
| # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None | ||
| # result when internet is up, the repo and revision exist, but the file does not. | ||
| if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME: | ||
| # Did not find the safetensors file, let's fallback to Flax. | ||
| # No support for sharded safetensors yet, so we'll raise an error if that's all we find. | ||
| filename = FLAX_WEIGHTS_NAME | ||
| resolved_archive_file = cached_file( | ||
| pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **cached_file_kwargs | ||
| ) | ||
| if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME: | ||
| # Maybe the checkpoint is sharded, we try to grab the index name in this case. | ||
| resolved_archive_file = cached_file( | ||
|
|
@@ -750,21 +808,26 @@ def from_pretrained( | |
| if resolved_archive_file is not None: | ||
| is_sharded = True | ||
| # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case. | ||
| elif resolved_archive_file is None and from_pt: | ||
| if resolved_archive_file is None and from_pt: | ||
| resolved_archive_file = cached_file( | ||
| pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs | ||
| ) | ||
| if resolved_archive_file is not None: | ||
| is_sharded = True | ||
| if resolved_archive_file is None: | ||
| # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error | ||
| # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error | ||
| # message. | ||
| has_file_kwargs = { | ||
| "revision": revision, | ||
| "proxies": proxies, | ||
| "token": token, | ||
| } | ||
| if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): | ||
| if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs): | ||
| is_sharded = True | ||
| raise NotImplementedError( | ||
| "Support for sharded checkpoints using safetensors is coming soon!" | ||
| ) | ||
| elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): | ||
| raise EnvironmentError( | ||
| f"{pretrained_model_name_or_path} does not appear to have a file named" | ||
| f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to" | ||
|
|
@@ -797,6 +860,7 @@ def from_pretrained( | |
| if is_local: | ||
| logger.info(f"loading weights file {archive_file}") | ||
| resolved_archive_file = archive_file | ||
| filename = resolved_archive_file.split(os.path.sep)[-1] | ||
| else: | ||
| logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") | ||
| else: | ||
|
|
@@ -820,31 +884,27 @@ def from_pretrained( | |
| _commit_hash=commit_hash, | ||
| ) | ||
|
|
||
| safetensors_from_pt = False | ||
| if filename == SAFE_WEIGHTS_NAME: | ||
| with safe_open(resolved_archive_file, framework="flax") as f: | ||
| safetensors_metadata = f.metadata() | ||
| if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]: | ||
| raise OSError( | ||
| f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." | ||
| " Make sure you save your model with the `save_pretrained` method." | ||
| ) | ||
| safetensors_from_pt = safetensors_metadata.get("format") == "pt" | ||
|
|
||
| # init random models | ||
| model = cls(config, *model_args, _do_init=_do_init, **model_kwargs) | ||
|
|
||
| if from_pt: | ||
| if from_pt or safetensors_from_pt: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Exposing my lack of knowledge about safe tensors here: if
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that's correct! This way we call
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for explaining! |
||
| state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded) | ||
| else: | ||
| if is_sharded: | ||
| state = cls.load_flax_sharded_weights(resolved_archive_file) | ||
| else: | ||
| try: | ||
| with open(resolved_archive_file, "rb") as state_f: | ||
| state = from_bytes(cls, state_f.read()) | ||
| except (UnpicklingError, msgpack.exceptions.ExtraData) as e: | ||
| try: | ||
| with open(resolved_archive_file) as f: | ||
| if f.read().startswith("version"): | ||
| raise OSError( | ||
| "You seem to have cloned a repository without having git-lfs installed. Please" | ||
| " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" | ||
| " folder you cloned." | ||
| ) | ||
| else: | ||
| raise ValueError from e | ||
| except (UnicodeDecodeError, ValueError): | ||
| raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ") | ||
| state = cls.load_flax_weights(resolved_archive_file) | ||
| # make sure all arrays are stored as jnp.arrays | ||
| # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: | ||
| # https://github.com/google/flax/issues/1261 | ||
|
|
@@ -1029,6 +1089,7 @@ def save_pretrained( | |
| push_to_hub=False, | ||
| max_shard_size="10GB", | ||
| token: Optional[Union[str, bool]] = None, | ||
| safe_serialization: bool = False, | ||
LysandreJik marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| **kwargs, | ||
| ): | ||
| """ | ||
|
|
@@ -1058,6 +1119,8 @@ def save_pretrained( | |
| the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). | ||
| kwargs (`Dict[str, Any]`, *optional*): | ||
| Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. | ||
| safe_serialization (`bool`, *optional*, defaults to `False`): | ||
| Whether to save the model using `safetensors` or through msgpack. | ||
| """ | ||
| use_auth_token = kwargs.pop("use_auth_token", None) | ||
|
|
||
|
|
@@ -1101,24 +1164,31 @@ def save_pretrained( | |
| self.generation_config.save_pretrained(save_directory) | ||
|
|
||
| # save model | ||
| output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME) | ||
| weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME | ||
| output_model_file = os.path.join(save_directory, weights_name) | ||
|
|
||
| shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size) | ||
| # Clean the folder from a previous save | ||
| for filename in os.listdir(save_directory): | ||
| full_filename = os.path.join(save_directory, filename) | ||
| weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") | ||
| if ( | ||
| filename.startswith(FLAX_WEIGHTS_NAME[:-4]) | ||
| filename.startswith(weights_no_suffix) | ||
| and os.path.isfile(full_filename) | ||
| and filename not in shards.keys() | ||
| ): | ||
| os.remove(full_filename) | ||
|
|
||
| if index is None: | ||
| with open(output_model_file, "wb") as f: | ||
| if safe_serialization: | ||
| params = params if params is not None else self.params | ||
| model_bytes = to_bytes(params) | ||
| f.write(model_bytes) | ||
| flat_dict = flatten_dict(params, sep=".") | ||
| safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"}) | ||
| else: | ||
| with open(output_model_file, "wb") as f: | ||
| params = params if params is not None else self.params | ||
| model_bytes = to_bytes(params) | ||
| f.write(model_bytes) | ||
|
|
||
| else: | ||
| save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.