diff --git a/src/transformers/commands/pt_to_flax.py b/src/transformers/commands/pt_to_flax.py new file mode 100644 index 000000000000..c6d8792d87d1 --- /dev/null +++ b/src/transformers/commands/pt_to_flax.py @@ -0,0 +1,364 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +from argparse import ArgumentParser, Namespace +from glob import glob +from importlib import import_module + +import numpy as np +from datasets import load_dataset +from packaging import version + +import huggingface_hub + +from .. import ( + FEATURE_EXTRACTOR_MAPPING, + PROCESSOR_MAPPING, + TOKENIZER_MAPPING, + AutoConfig, + AutoFeatureExtractor, + AutoProcessor, + AutoTokenizer, + is_flax_available, + is_torch_available, +) +from ..utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME, logging +from . import BaseTransformersCLICommand + + +if is_flax_available(): + import jax.numpy as jnp +if is_torch_available(): + import torch + + +MAX_ERROR = 5e-5 # larger error tolerance than in our internal tests, to avoid flaky user-facing errors + + +def convert_command_factory(args: Namespace): + """ + Factory function used to convert a model PyTorch checkpoint in a Flax checkpoint. + + Returns: ServeCommand + """ + return PTtoFXCommand( + args.model_name, + args.local_dir, + args.max_hidden_error, + args.new_weights, + args.no_pr, + args.push, + args.extra_commit_description, + ) + + +class PTtoFXCommand(BaseTransformersCLICommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + """ + Register this command to argparse so it's available for the transformer-cli + + Args: + parser: Root parser to register command-specific arguments + """ + train_parser = parser.add_parser( + "pt-to-fx", + help=( + "CLI tool to run convert a transformers model from a PyTorch checkpoint to a Flax checkpoint." + " Can also be used to validate existing weights without opening PRs, with --no-pr." + ), + ) + train_parser.add_argument( + "--model-name", + type=str, + required=True, + help="The model name, including owner/organization, as seen on the hub.", + ) + train_parser.add_argument( + "--local-dir", + type=str, + default="", + help="Optional local directory of the model repository. Defaults to /tmp/{model_name}", + ) + train_parser.add_argument( + "--max-hidden-error", + type=float, + default=MAX_ERROR, + help=( + f"Maximum error tolerance for hidden layer outputs. Defaults to {MAX_ERROR}. If you suspect the hidden" + " layers outputs will be used for downstream applications, avoid increasing this tolerance." + ), + ) + train_parser.add_argument( + "--new-weights", + action="store_true", + help="Optional flag to create new TensorFlow weights, even if they already exist.", + ) + train_parser.add_argument( + "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights." + ) + train_parser.add_argument( + "--push", + action="store_true", + help="Optional flag to push the weights directly to `main` (requires permissions)", + ) + train_parser.add_argument( + "--extra-commit-description", + type=str, + default="", + help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).", + ) + train_parser.set_defaults(func=convert_command_factory) + + @staticmethod + def find_pt_fx_differences(pt_outputs, fx_outputs): + """ + Compares the Flax and PyTorch outputs, returning a dictionary with all tensor differences. + """ + # 1. All output attributes must be the same + pt_out_attrs = set(pt_outputs.keys()) + fx_out_attrs = set(fx_outputs.keys()) + if pt_out_attrs != fx_out_attrs: + raise ValueError( + f"The model outputs have different attributes, aborting. (Pytorch: {pt_out_attrs}, Flax:" + f" {fx_out_attrs})" + ) + + # 2. For each output attribute, computes the difference + def _find_pt_fx_differences(pt_out, fx_out, differences, attr_name=""): + + # If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in + # recursivelly, keeping the name of the attribute. + if isinstance(pt_out, torch.Tensor): + tensor_difference = np.max(np.abs(pt_out.numpy() - fx_out.numpy())) + differences[attr_name] = tensor_difference + else: + root_name = attr_name + for i, pt_item in enumerate(pt_out): + # If it is a named attribute, we keep the name. Otherwise, just its index. + if isinstance(pt_item, str): + branch_name = root_name + pt_item + fx_item = fx_out[pt_item] + pt_item = pt_out[pt_item] + else: + branch_name = root_name + f"[{i}]" + fx_item = fx_out[i] + differences = _find_pt_fx_differences(pt_item, fx_item, differences, branch_name) + + return differences + + return _find_pt_fx_differences(pt_outputs, fx_outputs, {}) + + def __init__( + self, + model_name: str, + local_dir: str, + max_hidden_error: float, + new_weights: bool, + no_pr: bool, + push: bool, + extra_commit_description: str, + *args + ): + self._logger = logging.get_logger("transformers-cli/pt_to_fx") + self._model_name = model_name + self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name) + self._max_hidden_error = max_hidden_error + self._new_weights = new_weights + self._no_pr = no_pr + self._push = push + self._extra_commit_description = extra_commit_description + + def get_inputs(self, pt_model, config): + """ + Returns the right inputs for the model, based on its signature. + """ + + def _get_audio_input(): + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + speech_samples = ds.sort("id").select(range(2))[:2]["audio"] + raw_samples = [x["array"] for x in speech_samples] + return raw_samples + + model_forward_signature = set(inspect.signature(pt_model.forward).parameters.keys()) + processor_inputs = {} + if "input_ids" in model_forward_signature: + processor_inputs.update( + { + "text": ["Hi there!", "I am a batch with more than one row and different input lengths."], + "padding": True, + "truncation": True, + } + ) + if "pixel_values" in model_forward_signature: + sample_images = load_dataset("cifar10", "plain_text", split="test")[:2]["img"] + processor_inputs.update({"images": sample_images}) + if "input_features" in model_forward_signature: + processor_inputs.update({"raw_speech": _get_audio_input(), "padding": True}) + if "input_values" in model_forward_signature: # Wav2Vec2 audio input + processor_inputs.update({"raw_speech": _get_audio_input(), "padding": True}) + + model_config_class = type(pt_model.config) + if model_config_class in PROCESSOR_MAPPING: + processor = AutoProcessor.from_pretrained(self._local_dir) + if model_config_class in TOKENIZER_MAPPING and processor.tokenizer.pad_token is None: + processor.tokenizer.pad_token = processor.tokenizer.eos_token + elif model_config_class in FEATURE_EXTRACTOR_MAPPING: + processor = AutoFeatureExtractor.from_pretrained(self._local_dir) + elif model_config_class in TOKENIZER_MAPPING: + processor = AutoTokenizer.from_pretrained(self._local_dir) + if processor.pad_token is None: + processor.pad_token = processor.eos_token + else: + raise ValueError(f"Unknown data processing type (model config type: {model_config_class})") + + pt_input = processor(**processor_inputs, return_tensors="pt") + fx_input = processor(**processor_inputs, return_tensors="np") + + # Extra input requirements, in addition to the input modality + if config.is_encoder_decoder or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder")): + decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (pt_model.config.decoder_start_token_id or 0) + pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)}) + fx_input.update({"decoder_input_ids": jnp.array(decoder_input_ids)}) + + return pt_input, fx_input + + def run(self): + if version.parse(huggingface_hub.__version__) < version.parse("0.8.1"): + raise ImportError( + "The huggingface_hub version must be >= 0.8.1 to use this command. Please update your huggingface_hub" + " installation." + ) + else: + from huggingface_hub import Repository, create_commit + from huggingface_hub._commit_api import CommitOperationAdd + + # Fetch remote data + repo = Repository(local_dir=self._local_dir, clone_from=self._model_name) + + # Load config and get the appropriate architecture -- the latter is needed to convert the head's weights + config = AutoConfig.from_pretrained(self._local_dir) + architectures = config.architectures + if architectures is None: # No architecture defined -- use auto classes + pt_class = getattr(import_module("transformers"), "AutoModel") + fx_class = getattr(import_module("transformers"), "FlaxAutoModel") + self._logger.warn("No detected architecture, using AutoModel/FlaxAutoModel") + else: # Architecture defined -- use it + if len(architectures) > 1: + raise ValueError(f"More than one architecture was found, aborting. (architectures = {architectures})") + self._logger.warn(f"Detected architecture: {architectures[0]}") + pt_class = getattr(import_module("transformers"), architectures[0]) + try: + fx_class = getattr(import_module("transformers"), "Flax" + architectures[0]) + except AttributeError: + raise AttributeError(f"The Flax equivalent of {architectures[0]} doesn't exist in transformers.") + + # Load models and acquire a basic input compatible with the model. + pt_model = pt_class.from_pretrained(self._local_dir) + pt_model.eval() + + fx_from_pt_model = fx_class.from_pretrained( + self._local_dir, from_pt=True + ) # now also works for sharded checkpoints + pt_input, fx_input = self.get_inputs(pt_model, config) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_input, output_hidden_states=True) + del pt_model # will no longer be used, and may have a large memory footprint + + fx_from_pt_model = fx_class.from_pretrained(self._local_dir, from_pt=True) + fx_from_pt_outputs = fx_from_pt_model(**fx_input, output_hidden_states=True) + + # Confirms that cross loading PT weights into FX worked. + crossload_differences = self.find_pt_fx_differences(pt_outputs, fx_from_pt_outputs) + output_differences = {k: v for k, v in crossload_differences.items() if "hidden" not in k} + hidden_differences = {k: v for k, v in crossload_differences.items() if "hidden" in k} + max_crossload_output_diff = max(output_differences.values()) + max_crossload_hidden_diff = max(hidden_differences.values()) + if max_crossload_output_diff > MAX_ERROR or max_crossload_hidden_diff > self._max_hidden_error: + raise ValueError( + "The cross-loaded TensorFlow model has different outputs, something went wrong!\n" + + f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n" + + "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR]) + + f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n" + + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error]) + ) + + # Save the weights in a FX format (if needed) and confirms that the results are still good + fx_weights_path = os.path.join(self._local_dir, FLAX_WEIGHTS_NAME) + fx_weights_index_path = os.path.join(self._local_dir, FLAX_WEIGHTS_INDEX_NAME) + if (not os.path.exists(fx_weights_path) and not os.path.exists(fx_weights_index_path)) or self._new_weights: + fx_from_pt_model.save_pretrained(self._local_dir) + del fx_from_pt_model # will no longer be used, and may have a large memory footprint + + fx_model = fx_class.from_pretrained(self._local_dir) + fx_outputs = fx_model(**fx_input, output_hidden_states=True) + + conversion_differences = self.find_pt_fx_differences(pt_outputs, fx_outputs) + output_differences = {k: v for k, v in conversion_differences.items() if "hidden" not in k} + hidden_differences = {k: v for k, v in conversion_differences.items() if "hidden" in k} + max_conversion_output_diff = max(output_differences.values()) + max_conversion_hidden_diff = max(hidden_differences.values()) + if max_conversion_output_diff > MAX_ERROR or max_conversion_hidden_diff > self._max_hidden_error: + raise ValueError( + "The converted Flax model has different outputs, something went wrong!\n" + + f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n" + + "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR]) + + f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n" + + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error]) + ) + + commit_message = "Update FX weights" if self._new_weights else "Add FX weights" + if self._push: + repo.git_add(auto_lfs_track=True) + repo.git_commit(commit_message) + repo.git_push(blocking=True) # this prints a progress bar with the upload + self._logger.warn(f"FX weights pushed into {self._model_name}") + elif not self._no_pr: + self._logger.warn("Uploading the weights into a new PR...") + commit_descrition = ( + "Model converted by the [`transformers`' `pt_to_fx`" + " CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_fx.py). " + "All converted model outputs and hidden layers were validated against its Pytorch counterpart.\n\n" + f"Maximum crossload output difference={max_crossload_output_diff:.3e}; " + f"Maximum crossload hidden layer difference={max_crossload_hidden_diff:.3e};\n" + f"Maximum conversion output difference={max_conversion_output_diff:.3e}; " + f"Maximum conversion hidden layer difference={max_conversion_hidden_diff:.3e};\n" + ) + if self._extra_commit_description: + commit_descrition += "\n\n" + self._extra_commit_description + + # sharded model -> adds all related files (index and .h5 shards) + if os.path.exists(fx_weights_index_path): + operations = [ + CommitOperationAdd(path_in_repo=FLAX_WEIGHTS_INDEX_NAME, path_or_fileobj=fx_weights_index_path) + ] + for shard_path in glob.glob(self._local_dir + "/flax_model-*.msgpack"): + operations += [ + CommitOperationAdd(path_in_repo=os.path.basename(shard_path), path_or_fileobj=shard_path) + ] + else: + operations = [CommitOperationAdd(path_in_repo=FLAX_WEIGHTS_NAME, path_or_fileobj=fx_weights_path)] + + hub_pr_url = create_commit( + repo_id=self._model_name, + operations=operations, + commit_message=commit_message, + commit_description=commit_descrition, + repo_type="model", + create_pr=True, + ) + self._logger.warn(f"PR open in {hub_pr_url}")