28 changes: 22 additions & 6 deletions src/transformers/modeling_flax_pytorch_utils.py
@@ -27,9 +27,15 @@

import transformers

from . import is_safetensors_available
from .utils import logging


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.flax import load_file as safe_load_file


logger = logging.get_logger(__name__)


@@ -56,7 +62,13 @@ def load_pytorch_checkpoint_in_flax_state_dict(
    pt_path = os.path.abspath(pytorch_checkpoint_path)
    logger.info(f"Loading PyTorch weights from {pt_path}")

    pt_state_dict = torch.load(pt_path, map_location="cpu")
    if pt_path.endswith(".safetensors"):
        pt_state_dict = {}
        with safe_open(pt_path, framework="pt") as f:
            for k in f.keys():
                pt_state_dict[k] = f.get_tensor(k)
    else:
        pt_state_dict = torch.load(pt_path, map_location="cpu")
    logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

    flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
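A minimal standalone sketch of the new branch above, assuming torch and safetensors are installed (the file name and tensor are illustrative, not from the PR):

    import torch
    from safetensors import safe_open
    from safetensors.torch import save_file

    # Write a tiny PyTorch-format safetensors file, then read it back the way
    # the new branch does: iterate the keys and materialize each tensor.
    save_file({"weight": torch.ones(2, 2)}, "tiny.safetensors", metadata={"format": "pt"})

    pt_state_dict = {}
    with safe_open("tiny.safetensors", framework="pt") as f:
        for k in f.keys():
            pt_state_dict[k] = f.get_tensor(k)  # a torch.Tensor, because framework="pt"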
@@ -319,11 +331,15 @@ def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
    flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)

    # load flax weight dict
    with open(flax_checkpoint_path, "rb") as state_f:
        try:
            flax_state_dict = from_bytes(flax_cls, state_f.read())
        except UnpicklingError:
            raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")
    if flax_checkpoint_path.endswith(".safetensors"):
        flax_state_dict = safe_load_file(flax_checkpoint_path)
        flax_state_dict = unflatten_dict(flax_state_dict, sep=".")
    else:
        with open(flax_checkpoint_path, "rb") as state_f:
            try:
                flax_state_dict = from_bytes(flax_cls, state_f.read())
            except UnpicklingError:
                raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")

    return load_flax_weights_in_pytorch_model(model, flax_state_dict)
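The unflatten_dict(..., sep=".") call is needed because a safetensors file stores a flat {name: tensor} mapping, while Flax weights are nested dicts; the dot-separated key names encode the nesting. A small round-trip sketch with flax.traverse_util (values are illustrative):

    from flax.traverse_util import flatten_dict, unflatten_dict

    nested = {"encoder": {"layer_0": {"kernel": [1.0, 2.0]}}}
    flat = flatten_dict(nested, sep=".")  # {"encoder.layer_0.kernel": [1.0, 2.0]}
    assert unflatten_dict(flat, sep=".") == nested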

126 changes: 98 additions & 28 deletions src/transformers/modeling_flax_utils.py
@@ -39,6 +39,8 @@
from .utils import (
    FLAX_WEIGHTS_INDEX_NAME,
    FLAX_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    PushToHubMixin,
@@ -54,8 +56,14 @@
    replace_return_docstrings,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from .utils.import_utils import is_safetensors_available


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.flax import load_file as safe_load_file
    from safetensors.flax import save_file as safe_save_file

logger = logging.get_logger(__name__)


@@ -422,6 +430,31 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        ```"""
        return self._cast_floating_to(params, jnp.float16, mask)

    @classmethod
    def load_flax_weights(cls, resolved_archive_file):
        try:
            if resolved_archive_file.endswith(".safetensors"):
                state = safe_load_file(resolved_archive_file)
                state = unflatten_dict(state, sep=".")
            else:
                with open(resolved_archive_file, "rb") as state_f:
                    state = from_bytes(cls, state_f.read())
        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
            try:
                with open(resolved_archive_file) as f:
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please"
                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                            " folder you cloned."
                        )
                    else:
                        raise ValueError from e
            except (UnicodeDecodeError, ValueError):
                raise EnvironmentError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ")

        return state
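The startswith("version") check above is a heuristic for git-lfs pointer files: cloning a repo without git-lfs leaves a small text stub in place of the real weights, and the stub's first line is a version line. A sketch with illustrative pointer contents:

    # What an un-pulled git-lfs stub looks like in place of a real checkpoint.
    pointer_text = (
        "version https://git-lfs.github.com/spec/v1\n"
        "oid sha256:0123abcd...\n"
        "size 438016512\n"
    )
    assert pointer_text.startswith("version")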

    @classmethod
    def load_flax_sharded_weights(cls, shard_files):
        """
@@ -687,7 +720,12 @@ def from_pretrained(
            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
            is_local = os.path.isdir(pretrained_model_name_or_path)
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
                if is_safetensors_available() and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
                ):
                    # Load from a safetensors checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
                elif from_pt and os.path.isfile(
@@ -704,6 +742,13 @@
                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
                    is_sharded = True
                # At this stage we don't have a weight file so we will raise an error.
                elif is_safetensors_available() and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
                ):
                    # Load from a sharded safetensors checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
                    is_sharded = True
                    raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
                    raise EnvironmentError(
                        f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
@@ -722,7 +767,13 @@
                filename = pretrained_model_name_or_path
                resolved_archive_file = download_url(pretrained_model_name_or_path)
            else:
                filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
                if from_pt:
                    filename = WEIGHTS_NAME
                elif is_safetensors_available():
                    filename = SAFE_WEIGHTS_NAME
                else:
                    filename = FLAX_WEIGHTS_NAME

                try:
                    # Load from URL or cache if already cached
                    cached_file_kwargs = {
@@ -740,8 +791,15 @@
                    }
                    resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

                    # Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
                    # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                    # result when internet is up, the repo and revision exist, but the file does not.
                    if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
                        # Did not find the safetensors file, let's fallback to Flax.
                        # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
                        filename = FLAX_WEIGHTS_NAME
                        resolved_archive_file = cached_file(
                            pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **cached_file_kwargs
                        )
                    if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
                        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                        resolved_archive_file = cached_file(
@@ -750,21 +808,26 @@
                        if resolved_archive_file is not None:
                            is_sharded = True
                    # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
                    elif resolved_archive_file is None and from_pt:
                    if resolved_archive_file is None and from_pt:
                        resolved_archive_file = cached_file(
                            pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
                        )
                        if resolved_archive_file is not None:
                            is_sharded = True
                    if resolved_archive_file is None:
                        # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
                        # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error
                        # message.
                        has_file_kwargs = {
                            "revision": revision,
                            "proxies": proxies,
                            "token": token,
                        }
                        if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                        if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
                            is_sharded = True
                            raise NotImplementedError(
                                "Support for sharded checkpoints using safetensors is coming soon!"
                            )
                        elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                            raise EnvironmentError(
                                f"{pretrained_model_name_or_path} does not appear to have a file named"
                                f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
@@ -797,6 +860,7 @@ def from_pretrained(
            if is_local:
                logger.info(f"loading weights file {archive_file}")
                resolved_archive_file = archive_file
                filename = resolved_archive_file.split(os.path.sep)[-1]
            else:
                logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
        else:
@@ -820,31 +884,27 @@
                _commit_hash=commit_hash,
            )

        safetensors_from_pt = False
        if filename == SAFE_WEIGHTS_NAME:
            with safe_open(resolved_archive_file, framework="flax") as f:
                safetensors_metadata = f.metadata()
            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
                raise OSError(
                    f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
                    " Make sure you save your model with the `save_pretrained` method."
                )
            safetensors_from_pt = safetensors_metadata.get("format") == "pt"

        # init random models
        model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

        if from_pt:
        if from_pt or safetensors_from_pt:
Review thread on this line:

Contributor: Exposing my lack of knowledge about safetensors here: if safetensors_from_pt is True here, is the reason we do load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded) that the serialized weights are in "pytorch format" and therefore can't be loaded using cls.load_flax_weights?

Member (Author): Yes, that's correct! This way we call load_pytorch_checkpoint_in_flax_state_dict with the safetensors file, and that method checks whether the file ends with .safetensors to load it the PyTorch way.

Contributor: Thanks for explaining!
            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
        else:
            if is_sharded:
                state = cls.load_flax_sharded_weights(resolved_archive_file)
            else:
                try:
                    with open(resolved_archive_file, "rb") as state_f:
                        state = from_bytes(cls, state_f.read())
                except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                    try:
                        with open(resolved_archive_file) as f:
                            if f.read().startswith("version"):
                                raise OSError(
                                    "You seem to have cloned a repository without having git-lfs installed. Please"
                                    " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                                    " folder you cloned."
                                )
                            else:
                                raise ValueError from e
                    except (UnicodeDecodeError, ValueError):
                        raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
                state = cls.load_flax_weights(resolved_archive_file)
        # make sure all arrays are stored as jnp.arrays
        # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
        # https://github.com/google/flax/issues/1261
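Tying the review thread above together, a hypothetical end-to-end sketch of the metadata-driven dispatch (model name and paths are illustrative; torch, flax and safetensors are assumed installed). A safetensors file saved from the PyTorch side carries {"format": "pt"}, so the Flax loader routes it through load_pytorch_checkpoint_in_flax_state_dict instead of cls.load_flax_weights:

    from transformers import BertModel, FlaxBertModel

    # The PyTorch model writes model.safetensors with metadata {"format": "pt"} ...
    BertModel.from_pretrained("bert-base-cased").save_pretrained("pt-bert", safe_serialization=True)

    # ... so from_pretrained sets safetensors_from_pt and converts the weights.
    flax_model = FlaxBertModel.from_pretrained("pt-bert")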
@@ -1029,6 +1089,7 @@ def save_pretrained(
        push_to_hub=False,
        max_shard_size="10GB",
        token: Optional[Union[str, bool]] = None,
        safe_serialization: bool = False,
        **kwargs,
    ):
        """
@@ -1058,6 +1119,8 @@
                the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
            safe_serialization (`bool`, *optional*, defaults to `False`):
                Whether to save the model using `safetensors` or through msgpack.
        """
        use_auth_token = kwargs.pop("use_auth_token", None)
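A hedged usage sketch of the new flag (model name illustrative): with safe_serialization=True the checkpoint is written as model.safetensors instead of flax_model.msgpack, and from_pretrained prefers the safetensors file whenever the library is installed.

    from transformers import FlaxBertModel

    model = FlaxBertModel.from_pretrained("bert-base-cased")
    model.save_pretrained("local-flax-bert", safe_serialization=True)  # writes model.safetensors
    reloaded = FlaxBertModel.from_pretrained("local-flax-bert")        # picked up automatically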

@@ -1101,24 +1164,31 @@
            self.generation_config.save_pretrained(save_directory)

        # save model
        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
        weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME
        output_model_file = os.path.join(save_directory, weights_name)

        shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size)
        # Clean the folder from a previous save
        for filename in os.listdir(save_directory):
            full_filename = os.path.join(save_directory, filename)
            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
            if (
                filename.startswith(FLAX_WEIGHTS_NAME[:-4])
                filename.startswith(weights_no_suffix)
                and os.path.isfile(full_filename)
                and filename not in shards.keys()
            ):
                os.remove(full_filename)

        if index is None:
            with open(output_model_file, "wb") as f:
            if safe_serialization:
                params = params if params is not None else self.params
                model_bytes = to_bytes(params)
                f.write(model_bytes)
                flat_dict = flatten_dict(params, sep=".")
                safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"})
            else:
                with open(output_model_file, "wb") as f:
                    params = params if params is not None else self.params
                    model_bytes = to_bytes(params)
                    f.write(model_bytes)

        else:
            save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
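For reference, a sketch of what the safe-serialization branch writes, under assumed toy params: a flat, dot-keyed tensor mapping plus the format tag that the metadata check in from_pretrained expects.

    import jax.numpy as jnp
    from flax.traverse_util import flatten_dict
    from safetensors.flax import save_file as safe_save_file

    params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}}
    flat_dict = flatten_dict(params, sep=".")  # keys: "dense.kernel", "dense.bias"
    safe_save_file(flat_dict, "toy.safetensors", metadata={"format": "flax"})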