# [WIP] Lightning glue example #3290
**Merged**
11 commits:
- `4345814` :sparkles: Alter base pl transformer to use automodels (nateraw)
- `fbf147b` :bug: Add batch size env variable to function call (nateraw)
- `888e192` :lipstick: Apply black code style from Makefile (nateraw)
- `d25beb7` :truck: Move lightning base out of ner directory (nateraw)
- `59e6458` :sparkles: Add lightning glue example (nateraw)
- `43173a8` :lipstick: self (nateraw)
- `326a819` move _feature_file to base class (nateraw)
- `82546a7` :sparkles: Move eval logging to custom callback (nateraw)
- `7854db7` :lipstick: Apply black code style (nateraw)
- `dd1b783` :bug: Add parent to pythonpath, remove copy command (nateraw)
- `9361267` :bug: Add missing max_length kwarg (nateraw)
GLUE example README (new file, `@@ -0,0 +1,9 @@`):

```md
# GLUE Benchmark

Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/blob/master/examples/run_glue.py).

#### Run PyTorch version using PyTorch-Lightning

Run `bash run_pl.sh` from the `glue` directory. This will also install `pytorch-lightning` and the requirements in `examples/requirements.txt`. It is a shell pipeline that automatically downloads and pre-processes the data, then runs the specified model. Logs are saved in the `lightning_logs` directory.

Pass the `--n_gpu` flag to change the number of GPUs; the default is 1. At the end, the expected results are:

`TEST RESULTS {'val_loss': tensor(0.0707), 'precision': 0.852427800698191, 'recall': 0.869537067011978, 'f1': 0.8608974358974358}`
```
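For example, to train on two GPUs, `--n_gpu 2` can be appended to the training command in `run_pl.sh` below. A minimal sketch (it assumes `--n_gpu` is parsed by the script's generic arguments, as the README above implies; all other flags are as in the script):

```bash
python3 run_pl_glue.py --data_dir $DATA_DIR --model_type $MODEL_TYPE --task $TASK \
    --model_name_or_path $BERT_MODEL --output_dir $OUTPUT_DIR --max_seq_length $MAX_LENGTH \
    --learning_rate $LEARNING_RATE --num_train_epochs $NUM_EPOCHS --train_batch_size $BATCH_SIZE \
    --seed $SEED --do_train --do_predict --n_gpu 2
```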
`run_pl.sh` (new file, `@@ -0,0 +1,38 @@`):

```bash
# Install the newest pytorch-lightning.
pip install -U git+https://github.com/PyTorchLightning/pytorch-lightning/
# Install example requirements.
pip install -r ../requirements.txt

# Download the GLUE data.
python3 ../../utils/download_glue_data.py

export TASK=mrpc
export DATA_DIR=./glue_data/MRPC/
export MAX_LENGTH=128
export LEARNING_RATE=2e-5
export BERT_MODEL=bert-base-cased
export MODEL_TYPE=bert
export BATCH_SIZE=32
export NUM_EPOCHS=3
export SEED=2
export OUTPUT_DIR_NAME=mrpc-pl-bert
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}

# Make the output directory if it doesn't exist.
mkdir -p $OUTPUT_DIR
# Add the parent directory to the Python path to access transformer_base.py.
export PYTHONPATH="../":"${PYTHONPATH}"

python3 run_pl_glue.py --data_dir $DATA_DIR \
    --model_type $MODEL_TYPE \
    --task $TASK \
    --model_name_or_path $BERT_MODEL \
    --output_dir $OUTPUT_DIR \
    --max_seq_length $MAX_LENGTH \
    --learning_rate $LEARNING_RATE \
    --num_train_epochs $NUM_EPOCHS \
    --train_batch_size $BATCH_SIZE \
    --seed $SEED \
    --do_train \
    --do_predict
```
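The script selects the task entirely through the environment variables above. As a hedged illustration (assuming `download_glue_data.py` places the data under `./glue_data/SST-2/` and that `sst-2` is a valid key in `glue_processors`), switching the same pipeline to SST-2 would only require changing:

```bash
export TASK=sst-2
export DATA_DIR=./glue_data/SST-2/
export OUTPUT_DIR_NAME=sst2-pl-bert
```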
`run_pl_glue.py` (new file, `@@ -0,0 +1,196 @@`):

```python
import argparse
import glob
import logging
import os
import time

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from transformer_base import BaseTransformer, add_generic_args, generic_train
from transformers import glue_compute_metrics as compute_metrics
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import glue_output_modes
from transformers import glue_processors as processors
from transformers import glue_tasks_num_labels


logger = logging.getLogger(__name__)


class GLUETransformer(BaseTransformer):

    mode = "sequence-classification"

    def __init__(self, hparams):
        hparams.glue_output_mode = glue_output_modes[hparams.task]
        num_labels = glue_tasks_num_labels[hparams.task]

        super().__init__(hparams, num_labels, self.mode)

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}

        if self.hparams.model_type != "distilbert":
            inputs["token_type_ids"] = batch[2] if self.hparams.model_type in ["bert", "xlnet", "albert"] else None

        outputs = self(**inputs)
        loss = outputs[0]

        tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
        return {"loss": loss, "log": tensorboard_logs}

    def prepare_data(self):
        "Called to initialize data. Use the call to construct features"
        args = self.hparams
        processor = processors[args.task]()
        self.labels = processor.get_labels()

        for mode in ["train", "dev"]:
            cached_features_file = self._feature_file(mode)
            # Rebuild the cached features if they are missing or --overwrite_cache was passed
            if args.overwrite_cache or not os.path.exists(cached_features_file):
                logger.info("Creating features from dataset file at %s", args.data_dir)
                examples = (
                    processor.get_dev_examples(args.data_dir)
                    if mode == "dev"
                    else processor.get_train_examples(args.data_dir)
                )
                features = convert_examples_to_features(
                    examples,
                    self.tokenizer,
                    max_length=args.max_seq_length,
                    task=args.task,
                    label_list=self.labels,
                    output_mode=args.glue_output_mode,
                    pad_on_left=bool(args.model_type in ["xlnet"]),  # pad on the left for xlnet
                    pad_token=self.tokenizer.convert_tokens_to_ids([self.tokenizer.pad_token])[0],
                    pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
                )
                logger.info("Saving features into cached file %s", cached_features_file)
                torch.save(features, cached_features_file)

    def load_dataset(self, mode, batch_size):
        "Load datasets. Called after prepare data."

        # We test on the dev set to compare to benchmarks without having to submit to the GLUE server
        mode = "dev" if mode == "test" else mode

        cached_features_file = self._feature_file(mode)
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
        if self.hparams.glue_output_mode == "classification":
            all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
        elif self.hparams.glue_output_mode == "regression":
            all_labels = torch.tensor([f.label for f in features], dtype=torch.float)

        return DataLoader(
            TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels),
            batch_size=batch_size,
            shuffle=True,
        )
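
    # Note: the TensorDataset order above fixes the positional batch layout used by
    # training_step and validation_step:
    #   batch[0] input_ids, batch[1] attention_mask, batch[2] token_type_ids, batch[3] labels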

    def validation_step(self, batch, batch_idx):
        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}

        if self.hparams.model_type != "distilbert":
            inputs["token_type_ids"] = batch[2] if self.hparams.model_type in ["bert", "xlnet", "albert"] else None

        outputs = self(**inputs)
        tmp_eval_loss, logits = outputs[:2]
        preds = logits.detach().cpu().numpy()
        out_label_ids = inputs["labels"].detach().cpu().numpy()

        return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}

    def _eval_end(self, outputs):
        val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item()
        preds = np.concatenate([x["pred"] for x in outputs], axis=0)

        if self.hparams.glue_output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif self.hparams.glue_output_mode == "regression":
            preds = np.squeeze(preds)

        out_label_ids = np.concatenate([x["target"] for x in outputs], axis=0)
        # Placeholder lists returned below; they are not populated here
        out_label_list = [[] for _ in range(out_label_ids.shape[0])]
        preds_list = [[] for _ in range(out_label_ids.shape[0])]

        results = {**{"val_loss": val_loss_mean}, **compute_metrics(self.hparams.task, preds, out_label_ids)}

        ret = {k: v for k, v in results.items()}
        ret["log"] = results
        return ret, preds_list, out_label_list

    def validation_end(self, outputs: list) -> dict:
        ret, preds, targets = self._eval_end(outputs)
        logs = ret["log"]
        return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

    def test_epoch_end(self, outputs):
        # updating to test_epoch_end instead of deprecated test_end
        ret, predictions, targets = self._eval_end(outputs)

        # Converting to the dict required by pl
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
        # pytorch_lightning/trainer/logging.py#L139
        logs = ret["log"]
        # `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
        return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        # Add GLUE specific options
        BaseTransformer.add_model_specific_args(parser, root_dir)
        parser.add_argument(
            "--max_seq_length",
            default=128,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )

        parser.add_argument(
            "--task", default="", type=str, required=True, help="The GLUE task to run",
        )

        parser.add_argument(
            "--data_dir",
            default=None,
            type=str,
            required=True,
            help="The input data dir. Should contain the data files for the GLUE task.",
        )

        parser.add_argument(
            "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
        )

        return parser


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    parser = GLUETransformer.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    # If output_dir is not provided, a folder will be generated in pwd
    if args.output_dir is None:
        args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
        os.makedirs(args.output_dir)

    model = GLUETransformer(args)
    trainer = generic_train(model, args)

    # Optionally, predict on the dev set and write to output_dir
    if args.do_predict:
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        # Load the most recent checkpoint and keep a reference to it for testing
        model = GLUETransformer.load_from_checkpoint(checkpoints[-1])
        trainer.test(model)
```
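As a quick sanity check of the metrics used above, `compute_metrics` (imported from `transformers` at the top of the file) can be called standalone. A minimal sketch, assuming `transformers` and its `scikit-learn` dependency are installed; the arrays are toy values, not results from this PR:

```python
import numpy as np
from transformers import glue_compute_metrics

preds = np.array([1, 0, 1, 1, 0])   # toy predicted labels
labels = np.array([1, 0, 0, 1, 0])  # toy gold labels
# For the "mrpc" task this reports accuracy, F1, and their average.
print(glue_compute_metrics("mrpc", preds, labels))
```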