From 32a34d26cc993a3df61d480932c6a90918be002e Mon Sep 17 00:00:00 2001
From: Jin Young Sohn
Date: Wed, 1 Apr 2020 00:47:30 +0000
Subject: [PATCH] Use barriers to reduce duplicate work/resources

---
 examples/run_glue_tpu.py            | 79 +++++++++++++++++++----------
 transformers/configuration_utils.py |  4 +-
 transformers/file_utils.py          | 15 ++----
 transformers/modeling_utils.py      | 16 +++---
 transformers/tokenization_utils.py  |  5 +-
 5 files changed, 68 insertions(+), 51 deletions(-)

diff --git a/examples/run_glue_tpu.py b/examples/run_glue_tpu.py
index 18e5c0a00b1b..ff99d1f7af21 100644
--- a/examples/run_glue_tpu.py
+++ b/examples/run_glue_tpu.py
@@ -135,14 +135,21 @@ def train(args, train_dataset, model, tokenizer, disable_logging=False):
         train_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", total=len(dataloader), disable=disable_logging)
         for step, batch in enumerate(epoch_iterator):
+            # Save model checkpoint.
             if args.save_steps > 0 and global_step % args.save_steps == 0:
-                output_dir = os.path.join(args.output_dir, 'checkpoint-{}-xla{}'.format(global_step, xm.get_ordinal()))
+                output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                 logger.info("Saving model checkpoint to %s", output_dir)
-                if not os.path.exists(output_dir):
-                    os.makedirs(output_dir)
+
+                if xm.is_master_ordinal():
+                    if not os.path.exists(output_dir):
+                        os.makedirs(output_dir)
+                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
+
+                # Barrier to wait for saving checkpoint.
+                xm.rendezvous('mid_training_checkpoint')
+                # model.save_pretrained needs to be called by all ordinals
                 model.save_pretrained(output_dir, xla_device=True)
-                torch.save(args, os.path.join(output_dir, 'training_args.bin'))

             model.train()
             inputs = {'input_ids': batch[0],
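The hunk above replaces per-ordinal checkpoint directories ('checkpoint-{}-xla{}') with one shared directory guarded by a barrier: `xm.rendezvous(tag)` blocks each process until every ordinal has reached the same named point, so the master can lay out the directory before anyone writes into it. A minimal sketch of the intended flow, assuming an initialized torch_xla multiprocessing context (the `save_checkpoint` helper name is illustrative, not part of the patch):

```python
import os

import torch
import torch_xla.core.xla_model as xm


def save_checkpoint(args, model, global_step):
    # One shared directory for all ordinals instead of one per ordinal.
    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))

    if xm.is_master_ordinal():
        # Only the master creates the directory and writes the args file,
        # so the TPU processes do not race on the same paths.
        os.makedirs(output_dir, exist_ok=True)
        torch.save(args, os.path.join(output_dir, 'training_args.bin'))

    # Every process blocks here until all ordinals have arrived, so no
    # process calls save_pretrained() before the directory exists.
    xm.rendezvous('mid_training_checkpoint')

    # Must run on all ordinals: it gathers XLA tensors to CPU, and the
    # underlying xm.save() writes the weights from the master only.
    model.save_pretrained(output_dir, xla_device=True)
```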
@@ -263,15 +270,19 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):


 def load_and_cache_examples(args, task, tokenizer, evaluate=False):
+    if not xm.is_master_ordinal():
+        xm.rendezvous('load_and_cache_examples')
+
     processor = processors[task]()
     output_mode = output_modes[task]
+    cached_features_file = os.path.join(
+        args.data_dir, 'cached_{}_{}_{}_{}'.format(
+            'dev' if evaluate else 'train',
+            list(filter(None, args.model_name_or_path.split('/'))).pop(),
+            str(args.max_seq_length),
+            str(task)))
+
     # Load data features from cache or dataset file
-    cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}_xla{}'.format(
-        'dev' if evaluate else 'train',
-        list(filter(None, args.model_name_or_path.split('/'))).pop(),
-        str(args.max_seq_length),
-        str(task),
-        xm.get_ordinal()))
     if os.path.exists(cached_features_file) and not args.overwrite_cache:
         logger.info("Loading features from cached file %s", cached_features_file)
         features = torch.load(cached_features_file)
@@ -294,6 +305,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
         logger.info("Saving features into cached file %s", cached_features_file)
         torch.save(features, cached_features_file)

+    if xm.is_master_ordinal():
+        xm.rendezvous('load_and_cache_examples')
+
     # Convert to Tensors and build dataset
     all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
     all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
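These two hunks use the barrier asymmetrically: non-master ordinals enter `xm.rendezvous('load_and_cache_examples')` first and block, the master featurizes the dataset and writes the now-shared cache file (no more `_xla{ordinal}` suffix), and then the master enters the same tag, releasing everyone to load the file that now exists. A minimal sketch of the pattern, with the illustrative `build_features` callable standing in for the script's feature-conversion step:

```python
import os

import torch
import torch_xla.core.xla_model as xm


def load_or_build(cached_features_file, build_features):
    if not xm.is_master_ordinal():
        # Workers park here until the master reaches the same tag below.
        xm.rendezvous('load_and_cache_examples')

    if os.path.exists(cached_features_file):
        features = torch.load(cached_features_file)
    else:
        # Only the master can reach this branch on the first pass, so
        # there is exactly one writer of the shared cache file.
        features = build_features()
        torch.save(features, cached_features_file)

    if xm.is_master_ordinal():
        # The master arrives last; the waiting workers then resume and
        # take the torch.load() branch above, since the file now exists.
        xm.rendezvous('load_and_cache_examples')

    return features
```

Each ordinal hits the tag exactly once per call, which is what `xm.rendezvous` requires for the barrier to release cleanly.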
@@ -341,6 +355,9 @@ def main(args):
     label_list = processor.get_labels()
     num_labels = len(label_list)

+    if not xm.is_master_ordinal():
+        xm.rendezvous('download_only_once')  # Make sure only the first process in distributed training will download model & vocab
+
     # Load pretrained model and tokenizer
     args.model_type = args.model_type.lower()
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
@@ -359,47 +376,53 @@ def main(args):
         cache_dir=args.cache_dir if args.cache_dir else None,
         xla_device=True)

+    if xm.is_master_ordinal():
+        xm.rendezvous('download_only_once')
+
     # Send model to TPU/XLA device.
     model.to(args.device)

     logger.info("Training/evaluation parameters %s", args)

-    output_dir = os.path.join(args.output_dir, 'final-xla{}'.format(xm.get_ordinal()))
     if args.do_train:
         # Train the model.
         train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
         global_step, tr_loss = train(args, train_dataset, model, tokenizer, disable_logging=disable_logging)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

-        # Save trained model.
-        # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
-        # Create output directory if needed
-        if not os.path.exists(output_dir):
-            os.makedirs(output_dir)
+        if xm.is_master_ordinal():
+            # Save trained model.
+            # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
+
+            # Create output directory if needed
+            if not os.path.exists(args.output_dir):
+                os.makedirs(args.output_dir)

-        logger.info("Saving model checkpoint to %s", output_dir)
-        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
-        # They can then be reloaded using `from_pretrained()`
-        model.save_pretrained(output_dir, xla_device=True)
-        tokenizer.save_pretrained(output_dir)
+            logger.info("Saving model checkpoint to %s", args.output_dir)
+            # Save a trained model, configuration and tokenizer using `save_pretrained()`.
+            # They can then be reloaded using `from_pretrained()`
+            tokenizer.save_pretrained(args.output_dir)
+            # Good practice: save your training arguments together with the trained model.
+            torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

-        # Good practice: save your training arguments together with the trained.
-        torch.save(args, os.path.join(output_dir, 'training_args.bin'))
+        xm.rendezvous('post_training_checkpoint')
+        # model.save_pretrained needs to be called by all ordinals
+        model.save_pretrained(args.output_dir, xla_device=True)

         # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(output_dir, xla_device=True)
-        tokenizer = tokenizer_class.from_pretrained(output_dir, xla_device=True)
+        model = model_class.from_pretrained(args.output_dir, xla_device=True)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir, xla_device=True)
         model.to(args.device)

     # Evaluation
     results = {}
     if args.do_eval:
-        tokenizer = tokenizer_class.from_pretrained(output_dir, do_lower_case=args.do_lower_case, xla_device=True)
-        checkpoints = [output_dir]
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case, xla_device=True)
+        checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
+            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
             logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
diff --git a/transformers/configuration_utils.py b/transformers/configuration_utils.py
index 34b3071c311b..547bb69c5e57 100644
--- a/transformers/configuration_utils.py
+++ b/transformers/configuration_utils.py
@@ -122,7 +122,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         force_download = kwargs.pop('force_download', False)
         proxies = kwargs.pop('proxies', None)
         return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
-        xla_device = kwargs.pop('xla_device', False)

         if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
             config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
@@ -132,8 +131,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
             config_file = pretrained_model_name_or_path
         # redirect to the cache, if necessary
         try:
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
-                                               proxies=proxies, xla_device=xla_device)
+            resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
         except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
                 msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
diff --git a/transformers/file_utils.py b/transformers/file_utils.py
index e6fa1aa15358..8c64a8585c3d 100644
--- a/transformers/file_utils.py
+++ b/transformers/file_utils.py
@@ -102,7 +102,7 @@ def docstring_decorator(fn):
         return fn
     return docstring_decorator

-def url_to_filename(url, etag=None, xla_device=False):
+def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
@@ -120,10 +120,6 @@ def url_to_filename(url, etag=None):
         etag_hash = sha256(etag_bytes)
         filename += '.' + etag_hash.hexdigest()

-    if xla_device:
-        import torch_xla.core.xla_model as xm
-        filename += '.xla' + str(xm.get_ordinal())
-
     if url.endswith('.h5'):
         filename += '.h5'

@@ -156,7 +152,7 @@ def filename_to_url(filename, cache_dir=None):
     return url, etag

-def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, xla_device=False):
+def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
@@ -177,8 +173,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
     if parsed.scheme in ('http', 'https', 's3'):
         # URL, so get it from the cache (downloading if necessary)
-        return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download,
-                              proxies=proxies, xla_device=xla_device)
+        return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
     elif os.path.exists(url_or_filename):
         # File, and it exists.
         return url_or_filename
@@ -252,7 +247,7 @@ def http_get(url, temp_file, proxies=None):


 def get_from_cache(url, cache_dir=None, force_download=False, proxies=None,
-                   etag_timeout=10, xla_device=False):
+                   etag_timeout=10):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
@@ -282,7 +277,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None,
     if sys.version_info[0] == 2 and etag is not None:
         etag = etag.decode('utf-8')

-    filename = url_to_filename(url, etag, xla_device=xla_device)
+    filename = url_to_filename(url, etag)

     # get cache path to put the file
     cache_path = os.path.join(cache_dir, filename)
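With the `xla_device` plumbing removed, every TPU process resolves a pretrained URL to the same cache entry instead of one per ordinal. The naming scheme that `url_to_filename` keeps is roughly the sketch below; this is a paraphrase of the retained hashing logic (the `shared_cache_filename` name is ours), not new behavior:

```python
from hashlib import sha256


def shared_cache_filename(url, etag=None):
    # Hash the URL; the ordinal no longer appears in the name, so all
    # TPU processes map a given (url, etag) pair to one cached file.
    filename = sha256(url.encode('utf-8')).hexdigest()
    if etag is not None:
        filename += '.' + sha256(etag.encode('utf-8')).hexdigest()
    if url.endswith('.h5'):
        filename += '.h5'
    return filename
```

Combined with the 'download_only_once' barrier in run_glue_tpu.py, the master downloads once and every other ordinal finds the file already cached.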
diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py
index 1ffd8e48a64a..d08fa5cab2d3 100644
--- a/transformers/modeling_utils.py
+++ b/transformers/modeling_utils.py
@@ -245,17 +245,20 @@ def save_pretrained(self, save_directory, xla_device=False):
         # Only save the model itself if we are using distributed training
         model_to_save = self.module if hasattr(self, 'module') else self

-        # Save configuration file
-        model_to_save.config.save_pretrained(save_directory)
-
         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
+
+        # Save configuration file
         if xla_device:
-            # Saving for each process is save since output_dir is proc# namespaced.
             import torch_xla.core.xla_model as xm
-            xm.save(model_to_save.state_dict(), output_model_file, master_only=False)
+            if xm.is_master_ordinal():
+                model_to_save.config.save_pretrained(save_directory)
+            # xm.save takes care of saving only from master
+            xm.save(model_to_save.state_dict(), output_model_file)
         else:
+            model_to_save.config.save_pretrained(save_directory)
             torch.save(model_to_save.state_dict(), output_model_file)
+
         logger.info("Model weights saved in {}".format(output_model_file))

     @classmethod
@@ -371,8 +374,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         # redirect to the cache, if necessary
         try:
-            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
-                                                proxies=proxies, xla_device=xla_device)
+            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
         except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                 msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py
index b8037a4e3f5c..67d920942009 100644
--- a/transformers/tokenization_utils.py
+++ b/transformers/tokenization_utils.py
@@ -287,7 +287,6 @@ def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs)
         cache_dir = kwargs.pop('cache_dir', None)
         force_download = kwargs.pop('force_download', False)
         proxies = kwargs.pop('proxies', None)
-        xla_device = kwargs.pop('xla_device', False)

         s3_models = list(cls.max_model_input_sizes.keys())
         vocab_files = {}
@@ -343,7 +342,7 @@ def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs)
                     "We assumed '{}' was a path or url to a directory containing vocabulary files "
                     "named {} but couldn't find such vocabulary files at this path or url.".format(
                         pretrained_model_name_or_path, ', '.join(s3_models),
-                        pretrained_model_name_or_path, 
+                        pretrained_model_name_or_path,
                         list(cls.vocab_files_names.values())))

         # Get files from url, cache, or disk depending on the case
@@ -354,7 +353,7 @@ def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs)
                 resolved_vocab_files[file_id] = None
             else:
                 resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir,
-                                                            force_download=force_download, proxies=proxies, xla_device=xla_device)
+                                                            force_download=force_download, proxies=proxies)
             except EnvironmentError:
                 if pretrained_model_name_or_path in s3_models:
                     msg = "Couldn't reach server at '{}' to download vocabulary files."
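The reworked `save_pretrained` leans on `xm.save()`, which gathers XLA tensors to CPU and by default writes only from the master ordinal, so every process can call it without clobbering the weights file; only the plain-file config write needs an explicit `is_master_ordinal()` guard. A minimal sketch of that branch structure, assuming `WEIGHTS_NAME` matches the library's constant and `model.config` follows the transformers API:

```python
import os

import torch
import torch_xla.core.xla_model as xm

WEIGHTS_NAME = 'pytorch_model.bin'  # assumed; mirrors transformers' constant


def save_pretrained_sketch(model, save_directory, xla_device=False):
    output_model_file = os.path.join(save_directory, WEIGHTS_NAME)

    if xla_device:
        if xm.is_master_ordinal():
            # The config is an ordinary JSON write, so guard it to
            # keep a single writer.
            model.config.save_pretrained(save_directory)
        # Safe to call on every ordinal: xm.save() moves the state
        # dict to CPU and writes from the master only.
        xm.save(model.state_dict(), output_model_file)
    else:
        model.config.save_pretrained(save_directory)
        torch.save(model.state_dict(), output_model_file)
```

This is why the earlier `master_only=False` call and the per-ordinal output directories could both be dropped: one process writes, and the barrier in the caller keeps the rest in step.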