7 changes: 7 additions & 0 deletions examples/ner/README.md
@@ -112,6 +112,13 @@ Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) a
| `roberta-large` | 95.96 | 91.87
| `distilbert-base-uncased` | 94.34 | 90.32

#### Run the PyTorch version using PyTorch-Lightning

Run `bash run_pl.sh` from the `ner` directory. The script installs `pytorch-lightning` and the dependencies listed in `examples/requirements.txt`, then runs a shell pipeline that automatically downloads and pre-processes the data and trains the model, writing checkpoints to the `germeval-model` directory. Logs are saved to the `lightning_logs` directory.
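
Since logs go to `lightning_logs` (pytorch-lightning's default TensorBoard output), you can monitor a run with TensorBoard, assuming it is installed:

```bash
tensorboard --logdir lightning_logs
```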

Pass the `--n_gpu` flag to change the number of GPUs (the default is 1). At the end, the expected results are: `TEST RESULTS {'val_loss': tensor(0.0707), 'precision': 0.852427800698191, 'recall': 0.869537067011978, 'f1': 0.8608974358974358}`


### Run the TensorFlow 2 version

To start training, just run:
15 changes: 10 additions & 5 deletions examples/ner/run_pl.sh
@@ -1,6 +1,9 @@
#!/usr/bin/env bash

# Install the latest pytorch-lightning from source.
pip install -U git+https://github.com/PyTorchLightning/pytorch-lightning/

# Install the example requirements (seqeval is needed for the metrics).
pip install -r ../requirements.txt

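# Download the GermEval 2014 training split, drop the comment lines, and keep
# only the token and tag columns (converting tabs to spaces).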
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-train.tsv?attredirects=0&d=1' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp
@@ -15,12 +18,15 @@ python3 preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt
python3 preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt
python3 preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt
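# Build the label vocabulary: the unique tags (2nd space-separated column)
# across all three splits, one label per line.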
cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt
export BATCH_SIZE=32
export NUM_EPOCHS=3
export SEED=1

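# Resolve the output directory to an absolute path and create it up front.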
export OUTPUT_DIR_NAME=germeval-model
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
mkdir -p $OUTPUT_DIR

python3 run_pl_ner.py --data_dir ./ \
--model_type bert \
--labels ./labels.txt \
@@ -29,7 +35,6 @@ python3 run_pl_ner.py --data_dir ./ \
--max_seq_length $MAX_LENGTH \
--num_train_epochs $NUM_EPOCHS \
--train_batch_size $BATch_SIZE \
--seed $SEED \
--do_train \
--do_predict
21 changes: 17 additions & 4 deletions examples/ner/run_pl_ner.py
@@ -141,10 +141,14 @@ def _eval_end(self, outputs):
return ret, preds_list, out_label_list

def validation_end(self, outputs):
# TODO: update to validation_epoch_end instead of the deprecated
# validation_end once it is stable
ret, preds, targets = self._eval_end(outputs)
logs = ret["log"]
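# In this pytorch-lightning version, everything under "log" is sent to the
# logger and the "progress_bar" entries are displayed in the progress bar.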
return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

def test_epoch_end(self, outputs):
# Updated to test_epoch_end from the deprecated test_end.
ret, predictions, targets = self._eval_end(outputs)

if self.is_logger():
@@ -172,7 +176,12 @@ def test_end(self, outputs):
logger.warning(
"Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
)
# Convert to the dict format required by pytorch-lightning:
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
# pytorch_lightning/trainer/logging.py#L139
logs = ret["log"]
# `val_loss` is the key returned by `self._eval_end()`, but here it actually refers to the test loss
return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

@staticmethod
def add_model_specific_args(parser, root_dir):
@@ -217,6 +226,10 @@ def add_model_specific_args(parser, root_dir):
trainer = generic_train(model, args)

if args.do_predict:
checkpoints = list(sorted(glob.glob(args.output_dir + "/checkpoint_*.ckpt", recursive=True)))
# See https://github.com/huggingface/transformers/issues/3159
# pytorch-lightning uses this format to create a checkpoint:
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
# /pytorch_lightning/callbacks/model_checkpoint.py#L169
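# Note that sorted() compares the names as strings, so picking checkpoints[-1]
# as the latest checkpoint assumes the epoch numbers stay single-digit.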
checkpoints = list(sorted(glob.glob(args.output_dir + "/checkpointepoch=*.ckpt", recursive=True)))
model = NERTransformer.load_from_checkpoint(checkpoints[-1])
trainer.test(model)