Commit 837fac2

Add README section for TPU and address comments.
1 parent 5a44823 commit 837fac2

2 files changed (+61, -12 lines)

examples/README.md

Lines changed: 43 additions & 0 deletions
@@ -6,6 +6,7 @@ similar API between the different models.
 | Section | Description |
 |----------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|
 | [TensorFlow 2.0 models on GLUE](#TensorFlow-2.0-Bert-models-on-GLUE) | Examples running BERT TensorFlow 2.0 model on the GLUE tasks.
+| [Running on TPUs](#running-on-tpus) | Examples of running fine-tuning tasks on Google TPUs to accelerate workloads. |
 | [Language Model fine-tuning](#language-model-fine-tuning) | Fine-tuning the library models for language modeling on a text dataset. Causal language modeling for GPT/GPT-2, masked language modeling for BERT/RoBERTa. |
 | [Language Generation](#language-generation) | Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. |
 | [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
@@ -36,6 +37,48 @@ Quick benchmarks from the script (no other modifications):

 Mixed precision (AMP) reduces the training time considerably for the same hardware and hyper-parameters (same batch size was used).

+## Running on TPUs
+
+You can accelerate your workloads on Google's TPUs. For information on how to set up your TPU environment, refer to this
+[README](https://github.com/pytorch/xla/blob/master/README.md).
+
+The following are some examples of running the `*_tpu.py` fine-tuning scripts on TPUs. All steps for data preparation are
+identical to your normal GPU + Hugging Face setup.
+
+### GLUE
+
+Before running any of these GLUE tasks you should download the
+[GLUE data](https://gluebenchmark.com/tasks) by running
+[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
+and unpack it to some directory `$GLUE_DIR`.
+
+To run a GLUE task on the MNLI dataset, you can run something like the following:
+
+```
+export GLUE_DIR=/path/to/glue
+export TASK_NAME=MNLI
+
+python run_glue_tpu.py \
+  --model_type bert \
+  --model_name_or_path bert-base-cased \
+  --task_name $TASK_NAME \
+  --do_train \
+  --do_eval \
+  --do_lower_case \
+  --data_dir $GLUE_DIR/$TASK_NAME \
+  --max_seq_length 128 \
+  --train_batch_size 32 \
+  --learning_rate 3e-5 \
+  --num_train_epochs 3.0 \
+  --output_dir /tmp/$TASK_NAME \
+  --overwrite_output_dir \
+  --logging_steps 50 \
+  --save_steps 200 \
+  --num_cores=8 \
+  --only_log_master
+```
+
+
 ## Language model fine-tuning

 Based on the script [`run_lm_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_lm_finetuning.py).
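
For context on the pytorch/xla setup the new README section points to: the `*_tpu.py` scripts build on the torch_xla pattern of acquiring an XLA device, moving the model onto it, and reading batches through a per-device loader. Below is a minimal, illustrative sketch of that pattern only; it is not code from this commit, and the `run_on_tpu` and `train_dataloader` names are placeholders.

```python
# Illustrative sketch, not part of this commit. Assumes torch_xla is installed
# as described in the pytorch/xla README linked above.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def run_on_tpu(model, train_dataloader):
    device = xm.xla_device()      # the TPU core assigned to this process
    model = model.to(device)      # parameters now live on the XLA device
    # Wrap the regular DataLoader so batches are fed to this device.
    loader = pl.ParallelLoader(train_dataloader, [device]).per_device_loader(device)
    for batch in loader:
        outputs = model(**batch)  # placeholder forward pass; the real scripts compute a loss here
```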

examples/run_glue_tpu.py

Lines changed: 18 additions & 12 deletions
@@ -1,5 +1,6 @@
 # coding=utf-8
 # Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -77,8 +78,8 @@ def set_seed(args):
 def get_sampler(dataset):
     if xm.xrt_world_size() <= 1:
         return RandomSampler(dataset)
-    return DistributedSampler(dataset,
-                              num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+    return DistributedSampler(
+        dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


 def train(args, train_dataset, model, tokenizer, disable_logging=False):
@@ -97,8 +98,14 @@ def train(args, train_dataset, model, tokenizer, disable_logging=False):
     # Prepare optimizer and schedule (linear warmup and decay)
     no_decay = ['bias', 'LayerNorm.weight']
     optimizer_grouped_parameters = [
-        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
-        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
+        {
+            'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            'weight_decay': args.weight_decay,
+        },
+        {
+            'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            'weight_decay': 0.0,
+        },
     ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
     scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
@@ -129,8 +136,7 @@ def train(args, train_dataset, model, tokenizer, disable_logging=False):
                    logger.info("Saving model checkpoint to %s", output_dir)
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
-                   model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
-                   model_to_save.save_pretrained(output_dir, xla_device=True)
+                   model.save_pretrained(output_dir, xla_device=True)
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))

            model.train()
@@ -144,6 +150,7 @@ def train(args, train_dataset, model, tokenizer, disable_logging=False):
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.gradient_accumulation_steps > 1:
+               xm.mark_step()  # Mark step to evaluate graph so far or else graph will grow too big and OOM.
                loss = loss / args.gradient_accumulation_steps

            loss.backward()
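
A note on the `xm.mark_step()` call added in the hunk above: XLA builds its computation graph lazily, so without an explicit cut each accumulated micro-step keeps extending a single graph. Below is a hedged sketch of how this pattern typically fits into a full accumulation loop; it is illustrative only, and names such as `accumulate_and_step` and `accumulation_steps` are not taken from the script.

```python
# Illustrative sketch of gradient accumulation on an XLA device; not code from
# run_glue_tpu.py itself.
import torch_xla.core.xla_model as xm

def accumulate_and_step(model, loader, optimizer, accumulation_steps):
    model.train()
    for step, batch in enumerate(loader):
        loss = model(**batch)[0]
        if accumulation_steps > 1:
            # Cut the lazily built graph here, mirroring the diff above; otherwise
            # every accumulated step grows one graph until it runs out of memory.
            xm.mark_step()
            loss = loss / accumulation_steps
        loss.backward()
        if (step + 1) % accumulation_steps == 0:
            xm.optimizer_step(optimizer)  # all-reduce gradients across TPU cores, then step
            optimizer.zero_grad()
```
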
@@ -350,25 +357,24 @@ def main(args):

     logger.info("Training/evaluation parameters %s", args)

-    # Training
     if args.do_train:
+        # Train the model.
         train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
         global_step, tr_loss = train(args, train_dataset, model, tokenizer, disable_logging=disable_logging)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

+        # Save trained model.
+        # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
+        output_dir = os.path.join(args.output_dir, 'final-xla{}'.format(xm.get_ordinal()))

-    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
-    output_dir = os.path.join(args.output_dir, 'final-xla{}'.format(xm.get_ordinal()))
-    if args.do_train:
         # Create output directory if needed
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)

         logger.info("Saving model checkpoint to %s", output_dir)
         # Save a trained model, configuration and tokenizer using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
-        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
-        model_to_save.save_pretrained(output_dir, xla_device=True)
+        model.save_pretrained(output_dir, xla_device=True)
         tokenizer.save_pretrained(output_dir)

         # Good practice: save your training arguments together with the trained.
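
On the saving changes above: `save_pretrained(..., xla_device=True)` is now called directly on the model, and each process writes to its own `final-xla{ordinal}` directory. The sketch below illustrates that per-ordinal pattern; the `save_final_checkpoint` helper and the `xm.rendezvous` barrier are assumptions added for the example and are not part of this commit.

```python
# Illustrative sketch only; the rendezvous barrier is an assumption for the
# example and not something this diff adds.
import os
import torch_xla.core.xla_model as xm

def save_final_checkpoint(model, tokenizer, output_dir):
    # One directory per TPU core, keyed by ordinal, so cores never race on the
    # same files (mirrors the 'final-xla{}' naming in the hunk above).
    ordinal_dir = os.path.join(output_dir, 'final-xla{}'.format(xm.get_ordinal()))
    os.makedirs(ordinal_dir, exist_ok=True)
    model.save_pretrained(ordinal_dir, xla_device=True)  # same call as in the diff above
    tokenizer.save_pretrained(ordinal_dir)
    xm.rendezvous('checkpoint_saved')  # wait for every core to finish writing
```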
