@@ -158,7 +158,7 @@ def train(args, train_dataset, model, tokenizer):
158158 loss .backward ()
159159
160160 tr_loss += loss .item ()
161- if (step + 1 ) % args .gradient_accumulation_steps == 0 and not args . tpu :
161+ if (step + 1 ) % args .gradient_accumulation_steps == 0 :
162162 if args .fp16 :
163163 torch .nn .utils .clip_grad_norm_ (amp .master_params (optimizer ), args .max_grad_norm )
164164 else :
@@ -189,11 +189,6 @@ def train(args, train_dataset, model, tokenizer):
189189 torch .save (args , os .path .join (output_dir , 'training_args.bin' ))
190190 logger .info ("Saving model checkpoint to %s" , output_dir )
191191
192- if args .tpu :
193- args .xla_model .optimizer_step (optimizer , barrier = True )
194- model .zero_grad ()
195- global_step += 1
196-
197192 if args .max_steps > 0 and global_step > args .max_steps :
198193 epoch_iterator .close ()
199194 break
@@ -393,15 +388,6 @@ def main():
393388 parser .add_argument ('--seed' , type = int , default = 42 ,
394389 help = "random seed for initialization" )
395390
396- parser .add_argument ('--tpu' , action = 'store_true' ,
397- help = "Whether to run on the TPU defined in the environment variables" )
398- parser .add_argument ('--tpu_ip_address' , type = str , default = '' ,
399- help = "TPU IP address if none are set in the environment variables" )
400- parser .add_argument ('--tpu_name' , type = str , default = '' ,
401- help = "TPU name if none are set in the environment variables" )
402- parser .add_argument ('--xrt_tpu_config' , type = str , default = '' ,
403- help = "XRT TPU config if none are set in the environment variables" )
404-
405391 parser .add_argument ('--fp16' , action = 'store_true' ,
406392 help = "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit" )
407393 parser .add_argument ('--fp16_opt_level' , type = str , default = 'O1' ,
@@ -435,23 +421,6 @@ def main():
435421 args .n_gpu = 1
436422 args .device = device
437423
438- if args .tpu :
439- if args .tpu_ip_address :
440- os .environ ["TPU_IP_ADDRESS" ] = args .tpu_ip_address
441- if args .tpu_name :
442- os .environ ["TPU_NAME" ] = args .tpu_name
443- if args .xrt_tpu_config :
444- os .environ ["XRT_TPU_CONFIG" ] = args .xrt_tpu_config
445-
446- assert "TPU_IP_ADDRESS" in os .environ
447- assert "TPU_NAME" in os .environ
448- assert "XRT_TPU_CONFIG" in os .environ
449-
450- import torch_xla
451- import torch_xla .core .xla_model as xm
452- args .device = xm .xla_device ()
453- args .xla_model = xm
454-
455424 # Setup logging
456425 logging .basicConfig (format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' ,
457426 datefmt = '%m/%d/%Y %H:%M:%S' ,
0 commit comments