9 changes: 9 additions & 0 deletions examples/glue/README.md
@@ -0,0 +1,9 @@
# GLUE Benchmark

Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/blob/master/examples/run_glue.py).

#### Run the PyTorch version using PyTorch Lightning

Run `bash run_pl.sh` from the `glue` directory. This will also install `pytorch-lightning` and the requirements in `examples/requirements.txt`. The script is a shell pipeline that automatically downloads the data, pre-processes it, and runs the specified model. Logs are saved in the `lightning_logs` directory.

Pass the `--n_gpu` flag to change the number of GPUs (the default is 1). At the end, the expected results are: `TEST RESULTS {'val_loss': tensor(0.0707), 'precision': 0.852427800698191, 'recall': 0.869537067011978, 'f1': 0.8608974358974358}`
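
The metrics themselves come from the `glue_compute_metrics` helper that `run_pl_glue.py` imports from `transformers`. A minimal sketch with made-up toy predictions (not the numbers above; the exact keys reported depend on how metrics are wired into the training module):

```python
# Toy illustration of GLUE metric computation for MRPC; preds/labels invented.
import numpy as np
from transformers import glue_compute_metrics

preds = np.array([1, 0, 1, 1])
labels = np.array([1, 0, 0, 1])
print(glue_compute_metrics("mrpc", preds, labels))
# e.g. {'acc': 0.75, 'f1': 0.8, 'acc_and_f1': 0.775}
```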
38 changes: 38 additions & 0 deletions examples/glue/run_pl.sh
@@ -0,0 +1,38 @@
# Install the latest pytorch-lightning.
pip install -U git+https://github.com/PyTorchLightning/pytorch-lightning/
# Install example requirements
pip install -r ../requirements.txt

# Download glue data
python3 ../../utils/download_glue_data.py

export TASK=mrpc
export DATA_DIR=./glue_data/MRPC/
export MAX_LENGTH=128
export LEARNING_RATE=2e-5
export BERT_MODEL=bert-base-cased
export MODEL_TYPE=bert
export BATCH_SIZE=32
export NUM_EPOCHS=3
export SEED=2
export OUTPUT_DIR_NAME=mrpc-pl-bert
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}

# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access transformer_base.py
export PYTHONPATH="../":"${PYTHONPATH}"

python3 run_pl_glue.py --data_dir $DATA_DIR \
--model_type $MODEL_TYPE \
--task $TASK \
--model_name_or_path $BERT_MODEL \
--output_dir $OUTPUT_DIR \
--max_seq_length $MAX_LENGTH \
--learning_rate $LEARNING_RATE \
--num_train_epochs $NUM_EPOCHS \
--train_batch_size $BATCH_SIZE \
--seed $SEED \
--do_train \
--do_predict
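
For orientation before the module itself, here is a minimal sketch of what its `prepare_data()`/`load_dataset()` steps boil down to for MRPC, using the same `transformers` helpers the script imports. Paths and hyperparameters mirror the environment variables above; caching and the XLNet padding options are omitted for brevity:

```python
# Hedged sketch of the MRPC feature pipeline; assumes the data was
# downloaded to ./glue_data/MRPC/ by download_glue_data.py as above.
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import (
    BertTokenizer,
    glue_convert_examples_to_features as convert_examples_to_features,
    glue_processors,
)

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
processor = glue_processors["mrpc"]()
examples = processor.get_dev_examples("./glue_data/MRPC/")
features = convert_examples_to_features(
    examples, tokenizer, max_length=128, task="mrpc",
    label_list=processor.get_labels(), output_mode="classification",
)
dataset = TensorDataset(
    torch.tensor([f.input_ids for f in features], dtype=torch.long),
    torch.tensor([f.attention_mask for f in features], dtype=torch.long),
    torch.tensor([f.token_type_ids for f in features], dtype=torch.long),
    torch.tensor([f.label for f in features], dtype=torch.long),
)
loader = DataLoader(dataset, batch_size=32, shuffle=False)
```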
196 changes: 196 additions & 0 deletions examples/glue/run_pl_glue.py
@@ -0,0 +1,196 @@
import argparse
import glob
import logging
import os
import time

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from transformer_base import BaseTransformer, add_generic_args, generic_train
from transformers import glue_compute_metrics as compute_metrics
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import glue_output_modes
from transformers import glue_processors as processors
from transformers import glue_tasks_num_labels


logger = logging.getLogger(__name__)


class GLUETransformer(BaseTransformer):

mode = "sequence-classification"

def __init__(self, hparams):
hparams.glue_output_mode = glue_output_modes[hparams.task]
num_labels = glue_tasks_num_labels[hparams.task]

super().__init__(hparams, num_labels, self.mode)

def forward(self, **inputs):
return self.model(**inputs)

def training_step(self, batch, batch_idx):
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}

if self.hparams.model_type != "distilbert":
inputs["token_type_ids"] = batch[2] if self.hparams.model_type in ["bert", "xlnet", "albert"] else None

outputs = self(**inputs)
loss = outputs[0]

tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
return {"loss": loss, "log": tensorboard_logs}

def prepare_data(self):
"Called to initialize data. Use the call to construct features"
args = self.hparams
processor = processors[args.task]()
self.labels = processor.get_labels()

for mode in ["train", "dev"]:
cached_features_file = self._feature_file(mode)
            # Rebuild the cache if it is missing or an overwrite was requested.
            if not os.path.exists(cached_features_file) or args.overwrite_cache:
logger.info("Creating features from dataset file at %s", args.data_dir)
examples = (
processor.get_dev_examples(args.data_dir)
if mode == "dev"
else processor.get_train_examples(args.data_dir)
)
features = convert_examples_to_features(
examples,
self.tokenizer,
max_length=args.max_seq_length,
task=args.task,
label_list=self.labels,
output_mode=args.glue_output_mode,
pad_on_left=bool(args.model_type in ["xlnet"]), # pad on the left for xlnet
pad_token=self.tokenizer.convert_tokens_to_ids([self.tokenizer.pad_token])[0],
pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
)
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)

def load_dataset(self, mode, batch_size):
"Load datasets. Called after prepare data."

        # We test on the dev set to compare to benchmarks without submitting to the GLUE server
mode = "dev" if mode == "test" else mode

cached_features_file = self._feature_file(mode)
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
if self.hparams.glue_output_mode == "classification":
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
elif self.hparams.glue_output_mode == "regression":
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)

        # Only shuffle the training set; keep evaluation order deterministic.
        return DataLoader(
            TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels),
            batch_size=batch_size,
            shuffle=(mode == "train"),
        )

def validation_step(self, batch, batch_idx):
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}

if self.hparams.model_type != "distilbert":
inputs["token_type_ids"] = batch[2] if self.hparams.model_type in ["bert", "xlnet", "albert"] else None

outputs = self(**inputs)
tmp_eval_loss, logits = outputs[:2]
preds = logits.detach().cpu().numpy()
out_label_ids = inputs["labels"].detach().cpu().numpy()

return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}

def _eval_end(self, outputs):
val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item()
preds = np.concatenate([x["pred"] for x in outputs], axis=0)

if self.hparams.glue_output_mode == "classification":
preds = np.argmax(preds, axis=1)
elif self.hparams.glue_output_mode == "regression":
preds = np.squeeze(preds)

out_label_ids = np.concatenate([x["target"] for x in outputs], axis=0)
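        # These per-example lists mirror the NER example's _eval_end but are left
        # empty here; they are returned only for interface parity.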
out_label_list = [[] for _ in range(out_label_ids.shape[0])]
preds_list = [[] for _ in range(out_label_ids.shape[0])]

results = {**{"val_loss": val_loss_mean}, **compute_metrics(self.hparams.task, preds, out_label_ids)}

ret = {k: v for k, v in results.items()}
ret["log"] = results
return ret, preds_list, out_label_list

def validation_end(self, outputs: list) -> dict:
ret, preds, targets = self._eval_end(outputs)
logs = ret["log"]
return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

def test_epoch_end(self, outputs):
# updating to test_epoch_end instead of deprecated test_end
ret, predictions, targets = self._eval_end(outputs)

        # Converting to the dict required by pl
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
# pytorch_lightning/trainer/logging.py#L139
logs = ret["log"]
# `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

@staticmethod
def add_model_specific_args(parser, root_dir):
        # Add GLUE specific options
BaseTransformer.add_model_specific_args(parser, root_dir)
parser.add_argument(
"--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)

parser.add_argument(
"--task", default="", type=str, required=True, help="The GLUE task to run",
)

parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
)

parser.add_argument(
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
)

return parser


if __name__ == "__main__":
parser = argparse.ArgumentParser()
add_generic_args(parser, os.getcwd())
parser = GLUETransformer.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()

# If output_dir not provided, a folder will be generated in pwd
if args.output_dir is None:
        args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}")
os.makedirs(args.output_dir)

model = GLUETransformer(args)
trainer = generic_train(model, args)

# Optionally, predict on dev set and write to output_dir
if args.do_predict:
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        model = GLUETransformer.load_from_checkpoint(checkpoints[-1])
trainer.test(model)
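
Note that `run_pl_glue.py` calls `self._feature_file(mode)` without defining it; the helper is inherited from `BaseTransformer` in `transformer_base.py`. Judging by the copy deleted from `run_pl_ner.py` further down in this diff, the inherited version looks roughly like this (illustrative, not the authoritative source):

```python
# Presumed shape of BaseTransformer._feature_file, reconstructed from the
# implementation removed from run_pl_ner.py in this same PR.
import os

def _feature_file(self, mode):
    return os.path.join(
        self.hparams.data_dir,
        "cached_{}_{}_{}".format(
            mode,
            list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
            str(self.hparams.max_seq_length),
        ),
    )
```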
5 changes: 4 additions & 1 deletion examples/ner/run_pl.sh
100644 → 100755
@@ -27,14 +27,17 @@ export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
mkdir -p $OUTPUT_DIR

# Add parent directory to python path to access transformer_base.py
export PYTHONPATH="../":"${PYTHONPATH}"

python3 run_pl_ner.py --data_dir ./ \
--model_type bert \
--labels ./labels.txt \
--model_name_or_path $BERT_MODEL \
--output_dir $OUTPUT_DIR \
--max_seq_length $MAX_LENGTH \
--num_train_epochs $NUM_EPOCHS \
--train_batch_size 32 \
--train_batch_size $BATCH_SIZE \
--seed $SEED \
--do_train \
--do_predict
55 changes: 8 additions & 47 deletions examples/ner/run_pl_ner.py
@@ -21,11 +21,13 @@ class NERTransformer(BaseTransformer):
A training module for NER. See BaseTransformer for the core options.
"""

mode = "token-classification"

def __init__(self, hparams):
self.labels = get_labels(hparams.labels)
num_labels = len(self.labels)
self.pad_token_label_id = CrossEntropyLoss().ignore_index
super(NERTransformer, self).__init__(hparams, num_labels)
super(NERTransformer, self).__init__(hparams, num_labels, self.mode)

def forward(self, **inputs):
return self.model(**inputs)
@@ -38,21 +40,11 @@ def training_step(self, batch, batch_num):
batch[2] if self.hparams.model_type in ["bert", "xlnet"] else None
        ) # XLM and RoBERTa don't use segment_ids

outputs = self.forward(**inputs)
outputs = self(**inputs)
loss = outputs[0]
tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
return {"loss": loss, "log": tensorboard_logs}

def _feature_file(self, mode):
return os.path.join(
self.hparams.data_dir,
"cached_{}_{}_{}".format(
mode,
list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
str(self.hparams.max_seq_length),
),
)

def prepare_data(self):
"Called to initialize data. Use the call to construct features"
args = self.hparams
@@ -100,7 +92,7 @@ def validation_step(self, batch, batch_nb):
inputs["token_type_ids"] = (
batch[2] if self.hparams.model_type in ["bert", "xlnet"] else None
        ) # XLM and RoBERTa don't use segment_ids
outputs = self.forward(**inputs)
outputs = self(**inputs)
tmp_eval_loss, logits = outputs[:2]
preds = logits.detach().cpu().numpy()
out_label_ids = inputs["labels"].detach().cpu().numpy()
@@ -130,14 +122,8 @@ def _eval_end(self, outputs):
"f1": f1_score(out_label_list, preds_list),
}

if self.is_logger():
logger.info("***** Eval results *****")
for key in sorted(results.keys()):
logger.info(" %s = %s", key, str(results[key]))

tensorboard_logs = results
ret = {k: v for k, v in results.items()}
ret["log"] = tensorboard_logs
ret["log"] = results
return ret, preds_list, out_label_list

def validation_end(self, outputs):
@@ -151,32 +137,7 @@ def test_epoch_end(self, outputs):
# updating to test_epoch_end instead of deprecated test_end
ret, predictions, targets = self._eval_end(outputs)

if self.is_logger():
# Write output to a file:
# Save results
output_test_results_file = os.path.join(self.hparams.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key in sorted(ret.keys()):
if key != "log":
writer.write("{} = {}\n".format(key, str(ret[key])))
# Save predictions
output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
with open(output_test_predictions_file, "w") as writer:
with open(os.path.join(self.hparams.data_dir, "test.txt"), "r") as f:
example_id = 0
for line in f:
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
writer.write(line)
if not predictions[example_id]:
example_id += 1
elif predictions[example_id]:
output_line = line.split()[0] + " " + predictions[example_id].pop(0) + "\n"
writer.write(output_line)
else:
logger.warning(
"Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
)
# Converting to the dic required by pl
# Converting to the dict required by pl
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
# pytorch_lightning/trainer/logging.py#L139
logs = ret["log"]
@@ -230,6 +191,6 @@ def add_model_specific_args(parser, root_dir):
# pl use this format to create a checkpoint:
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
# /pytorch_lightning/callbacks/model_checkpoint.py#L169
checkpoints = list(sorted(glob.glob(args.output_dir + "/checkpointepoch=*.ckpt", recursive=True)))
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        model = NERTransformer.load_from_checkpoint(checkpoints[-1])
trainer.test(model)