Merged
35 changes: 23 additions & 12 deletions examples/run_glue_tpu.py
@@ -18,13 +18,14 @@
from __future__ import absolute_import, division, print_function

import argparse
from collections import defaultdict
import glob
import logging
import math
import multiprocessing
import os
import time
import random
import time

import numpy as np
import torch
@@ -205,18 +206,18 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):

# Loop to handle MNLI double evaluation (matched, mis-matched)
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
output_dir = '{}/eval-xla{}'.format(args.output_dir, xm.get_ordinal())
eval_outputs_dirs = (output_dir, output_dir + '-MM') if args.task_name == "mnli" else (output_dir,)
eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)

results = {}
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
eval_sampler = get_sampler(eval_dataset)

if not os.path.exists(eval_output_dir):
os.makedirs(eval_output_dir)

# Note that we don't shard for TPU Multiprocess as we don't reduce loss among client processes.
dataloader = DataLoader(eval_dataset, batch_size=args.eval_batch_size, shuffle=False)
dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, shuffle=False)
eval_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

# Eval!
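
An aside on the sharding change in the hunk above: get_sampler(eval_dataset) is defined elsewhere in run_glue_tpu.py and is not shown in this diff. The snippet below is only a plausible sketch of such a helper, assuming it shards the eval set across TPU processes with DistributedSampler keyed on torch_xla's ordinal and world size; the actual helper may differ.

# Sketch only (not part of this diff): a hypothetical get_sampler() that
# shards a dataset across TPU processes. The real helper in run_glue_tpu.py
# may be implemented differently.
import torch_xla.core.xla_model as xm
from torch.utils.data import SequentialSampler
from torch.utils.data.distributed import DistributedSampler

def get_sampler(dataset):
    # Single-process runs need no sharding.
    if xm.xrt_world_size() <= 1:
        return SequentialSampler(dataset)
    # Each TPU process reads a disjoint shard, identified by its ordinal.
    return DistributedSampler(dataset,
                              num_replicas=xm.xrt_world_size(),
                              rank=xm.get_ordinal())
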
@@ -238,9 +239,9 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
# XLM, DistilBERT and RoBERTa don't use segment_ids
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None
outputs = model(**inputs)
tmp_eval_loss, logits = outputs[:2]
batch_eval_loss, logits = outputs[:2]

eval_loss += tmp_eval_loss.mean().item()
eval_loss += batch_eval_loss
nb_eval_steps += 1
if preds is None:
preds = logits.detach().cpu().numpy()
@@ -249,21 +250,31 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)

# Get all predictions and labels from all workers
preds = xm.mesh_reduce('eval_preds', preds, np.concatenate)
out_label_ids = xm.mesh_reduce(
'eval_out_label_ids', out_label_ids, np.concatenate)

eval_loss = eval_loss / nb_eval_steps
if args.output_mode == "classification":
preds = np.argmax(preds, axis=1)
elif args.output_mode == "regression":
preds = np.squeeze(preds)
result = compute_metrics(eval_task, preds, out_label_ids)
results.update(result)
results['eval_loss'] = eval_loss.item()

output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
tb_writer.add_scalar(key, result[key])
if xm.is_master_ordinal():

Review thread on the xm.is_master_ordinal() line above:

Reviewer: is everything being logged here already on cpu?

Author: Yes.

Reviewer: maybe let's add a comment? It's a subtle point that can be missed by code readers.

(A sketch of why these values are already host-side follows after the diff.)

with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(results.keys()):
logger.info(" %s = %s", key, str(results[key]))
writer.write("%s = %s\n" % (key, str(results[key])))
tb_writer.add_scalar(key, results[key])

if args.metrics_debug:
xm.master_print(met.metrics_report())

tb_writer.close()
return results
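
On the review thread above (whether everything written under xm.is_master_ordinal() is already on CPU): below is a minimal sketch of the reasoning, not code from the PR. It assumes torch_xla's xm.mesh_reduce gathers each process's value and applies the given reduce function on the host, as it is used in the diff; the helper name is hypothetical.

# Sketch only (not part of the PR): why the logged values are host-side.
import numpy as np
import torch_xla.core.xla_model as xm

def gather_eval_outputs(preds, out_label_ids, eval_loss, nb_eval_steps):
    # Hypothetical helper mirroring the diff above. mesh_reduce rendezvouses
    # the per-process values and applies np.concatenate to the gathered list,
    # so the results are plain NumPy arrays in host (CPU) memory.
    preds = xm.mesh_reduce('eval_preds', preds, np.concatenate)
    out_label_ids = xm.mesh_reduce('eval_out_label_ids', out_label_ids, np.concatenate)
    # eval_loss is still an XLA tensor at this point; .item() materializes it
    # as a Python float on the host, so nothing logged afterwards lives on the
    # device.
    eval_loss = (eval_loss / nb_eval_steps).item()
    return preds, out_label_ids, eval_loss

This is the subtlety the reviewer suggests documenting with an inline comment.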