
Commit b421758

Cleanup TPU bits from run_glue.py (#3)
The TPU runner is currently implemented in https://github.com/pytorch-tpu/transformers/blob/tpu/examples/run_glue_tpu.py. We plan to upstream this directly into `huggingface/transformers` (either the `master` branch or a `tpu` branch) once it has been more thoroughly tested.
1 parent e056eff commit b421758

File tree

1 file changed: +1, -32 lines changed

examples/run_glue.py

Lines changed: 1 addition & 32 deletions
@@ -158,7 +158,7 @@ def train(args, train_dataset, model, tokenizer):
                 loss.backward()
 
             tr_loss += loss.item()
-            if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
+            if (step + 1) % args.gradient_accumulation_steps == 0:
                 if args.fp16:
                     torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                 else:
@@ -189,11 +189,6 @@ def train(args, train_dataset, model, tokenizer):
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)
 
-            if args.tpu:
-                args.xla_model.optimizer_step(optimizer, barrier=True)
-                model.zero_grad()
-                global_step += 1
-
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
@@ -393,15 +388,6 @@ def main():
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")
 
-    parser.add_argument('--tpu', action='store_true',
-                        help="Whether to run on the TPU defined in the environment variables")
-    parser.add_argument('--tpu_ip_address', type=str, default='',
-                        help="TPU IP address if none are set in the environment variables")
-    parser.add_argument('--tpu_name', type=str, default='',
-                        help="TPU name if none are set in the environment variables")
-    parser.add_argument('--xrt_tpu_config', type=str, default='',
-                        help="XRT TPU config if none are set in the environment variables")
-
     parser.add_argument('--fp16', action='store_true',
                         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
     parser.add_argument('--fp16_opt_level', type=str, default='O1',
@@ -435,23 +421,6 @@ def main():
         args.n_gpu = 1
     args.device = device
 
-    if args.tpu:
-        if args.tpu_ip_address:
-            os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
-        if args.tpu_name:
-            os.environ["TPU_NAME"] = args.tpu_name
-        if args.xrt_tpu_config:
-            os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config
-
-        assert "TPU_IP_ADDRESS" in os.environ
-        assert "TPU_NAME" in os.environ
-        assert "XRT_TPU_CONFIG" in os.environ
-
-        import torch_xla
-        import torch_xla.core.xla_model as xm
-        args.device = xm.xla_device()
-        args.xla_model = xm
-
     # Setup logging
     logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                         datefmt = '%m/%d/%Y %H:%M:%S',
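
For context, the removed block boils down to the standard single-core `torch_xla` pattern: pick the XLA device with `xm.xla_device()` and step the optimizer with `xm.optimizer_step(optimizer, barrier=True)`. Below is a minimal sketch of that pattern, assuming `torch_xla` is installed and `XRT_TPU_CONFIG` is exported; the toy model, optimizer, and loop are illustrative stand-ins, not code from this repository.

import os
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm  # assumes torch_xla is installed and a TPU/XRT env is configured

# The removed code asserted this was set (via the dropped --xrt_tpu_config / --tpu_ip_address flags).
assert "XRT_TPU_CONFIG" in os.environ

device = xm.xla_device()                # XLA/TPU device instead of cuda/cpu
model = nn.Linear(16, 2).to(device)     # toy stand-in for the GLUE model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(3):                      # toy training loop
    inputs = torch.randn(8, 16, device=device)
    labels = torch.randint(0, 2, (8,), device=device)
    loss = nn.functional.cross_entropy(model(inputs), labels)
    loss.backward()
    # barrier=True runs the step synchronously, mirroring the removed
    # args.xla_model.optimizer_step(optimizer, barrier=True) call.
    xm.optimizer_step(optimizer, barrier=True)
    model.zero_grad()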
