diff --git a/examples/run_glue_tpu.py b/examples/run_glue_tpu.py
index ff99d1f7af21..e40079b2f677 100644
--- a/examples/run_glue_tpu.py
+++ b/examples/run_glue_tpu.py
@@ -18,13 +18,14 @@ from __future__ import absolute_import, division, print_function
 
 import argparse
+from collections import defaultdict
 import glob
 import logging
 import math
 import multiprocessing
 import os
-import time
 import random
+import time
 
 import numpy as np
 import torch
 
@@ -205,18 +206,18 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
 
     # Loop to handle MNLI double evaluation (matched, mis-matched)
     eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
-    output_dir = '{}/eval-xla{}'.format(args.output_dir, xm.get_ordinal())
-    eval_outputs_dirs = (output_dir, output_dir + '-MM') if args.task_name == "mnli" else (output_dir,)
+    eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)
 
     results = {}
     for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
         eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
+        eval_sampler = get_sampler(eval_dataset)
 
         if not os.path.exists(eval_output_dir):
             os.makedirs(eval_output_dir)
 
         # Note that we don't shard for TPU Multiprocess as we don't reduce loss among client processes.
-        dataloader = DataLoader(eval_dataset, batch_size=args.eval_batch_size, shuffle=False)
+        dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, shuffle=False)
         eval_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
 
         # Eval!
@@ -238,9 +239,9 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
                 # XLM, DistilBERT and RoBERTa don't use segment_ids
                 inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None
                 outputs = model(**inputs)
-                tmp_eval_loss, logits = outputs[:2]
+                batch_eval_loss, logits = outputs[:2]
 
-                eval_loss += tmp_eval_loss.mean().item()
+                eval_loss += batch_eval_loss
             nb_eval_steps += 1
             if preds is None:
                 preds = logits.detach().cpu().numpy()
@@ -249,6 +250,11 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
                 preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                 out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
 
+        # Get all predictions and labels from all workers
+        preds = xm.mesh_reduce('eval_preds', preds, np.concatenate)
+        out_label_ids = xm.mesh_reduce(
+            'eval_out_label_ids', out_label_ids, np.concatenate)
+
         eval_loss = eval_loss / nb_eval_steps
         if args.output_mode == "classification":
             preds = np.argmax(preds, axis=1)
@@ -256,14 +262,19 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
             preds = np.squeeze(preds)
         result = compute_metrics(eval_task, preds, out_label_ids)
         results.update(result)
+        results['eval_loss'] = eval_loss.item()
 
         output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
-        with open(output_eval_file, "w") as writer:
-            logger.info("***** Eval results {} *****".format(prefix))
-            for key in sorted(result.keys()):
-                logger.info("  %s = %s", key, str(result[key]))
-                writer.write("%s = %s\n" % (key, str(result[key])))
-                tb_writer.add_scalar(key, result[key])
+        if xm.is_master_ordinal():
+            with open(output_eval_file, "w") as writer:
+                logger.info("***** Eval results {} *****".format(prefix))
+                for key in sorted(results.keys()):
+                    logger.info("  %s = %s", key, str(results[key]))
+                    writer.write("%s = %s\n" % (key, str(results[key])))
+                    tb_writer.add_scalar(key, results[key])
+
+    if args.metrics_debug:
+        xm.master_print(met.metrics_report())
 
     tb_writer.close()
     return results