@@ -29,8 +29,10 @@
 import sys
 import time
 import warnings
+import torch_xla.debug.profiler as xp
 from collections.abc import Mapping
 from pathlib import Path
+from threading import Thread
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 
@@ -162,6 +164,7 @@
     import datasets
 
 if is_torch_tpu_available(check_device=False):
+    import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
 
@@ -838,7 +841,8 @@ def get_train_dataloader(self) -> DataLoader:
             dataloader_params["drop_last"] = self.args.dataloader_drop_last
             dataloader_params["worker_init_fn"] = seed_worker
 
-        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+        # TODO(jonbolin): Disabling Accelerate on the dataloader (`Unknown device SPMD:0`)
+        return DataLoader(train_dataset, **dataloader_params)
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
         # Deprecated code
@@ -1444,6 +1448,21 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
 
         return model
 
+    def _xla_sharded_dataloader(self, dataloader):
+        if is_torch_tpu_available():
+            # parallel_loader is needed whether or not batch sharding is enabled,
+            # so import it outside of the spmd_batch_sharding branch.
+            import torch_xla.distributed.parallel_loader as pl
+
+            sharding_spec = None
+            if self.args.spmd_batch_sharding:
+                import torch_xla.experimental.xla_sharding as xs
+                import torch_xla.runtime as xr
+
+                # Shard the batch dimension (dim 0) of each input across a
+                # (num_devices, 1) mesh covering all addressable devices.
+                num_devices = xr.global_device_count()
+                device_ids = np.arange(num_devices)
+                mesh = xs.Mesh(device_ids, (num_devices, 1))
+                sharding_spec = xs.ShardingSpec(mesh, (0, 1))
+            return pl.MpDeviceLoader(
+                dataloader,
+                self.args.device,
+                input_sharding=sharding_spec,
+                loader_prefetch_size=self.args.train_batch_size,
+                device_prefetch_size=4,
+            )
+        else:
+            return dataloader
+
     def train(
         self,
         resume_from_checkpoint: Optional[Union[str, bool]] = None,
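For reference, the batch-sharding pattern introduced in `_xla_sharded_dataloader` can be reproduced in isolation. The following is a minimal sketch, not part of the patch, assuming a torch_xla SPMD runtime and using only the calls that appear in the hunk above (`xr.global_device_count`, `xs.Mesh`, `xs.ShardingSpec`, `pl.MpDeviceLoader`); the dataset and batch size are placeholders.

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

# 1D logical mesh over all devices; partition spec (0, 1) shards dim 0 (the batch)
# across the first mesh axis and leaves dim 1 unsharded.
num_devices = xr.global_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1))
sharding_spec = xs.ShardingSpec(mesh, (0, 1))

# Placeholder data: 128 examples with 16 features each.
loader = DataLoader(TensorDataset(torch.randn(128, 16)), batch_size=32)

# MpDeviceLoader moves each batch to the XLA device and applies the input sharding,
# so every device holds only its slice of the global batch.
device_loader = pl.MpDeviceLoader(loader, xm.xla_device(), input_sharding=sharding_spec)
for (batch,) in device_loader:
    pass  # a training step would consume the sharded batch here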
@@ -1537,7 +1556,7 @@ def _inner_training_loop(
         self._train_batch_size = batch_size
         logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
         # Data loader and number of training steps
-        train_dataloader = self.get_train_dataloader()
+        train_dataloader = self._xla_sharded_dataloader(self.get_train_dataloader())
 
         # Setting up training control variables:
         # number of training epochs: num_train_epochs
@@ -1771,7 +1790,13 @@ def _inner_training_loop(
                 rng_to_sync = True
 
             step = -1
+            profile_step = int(os.environ.get('PROFILE_STEP', -1))
+            profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
+            profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
+            profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
             for step, inputs in enumerate(epoch_iterator):
+                if step == 0 and epoch == 0:
+                    print('input sharding', {k: (v.shape, torch_xla._XLAC._get_xla_sharding_spec(v)) for k, v in inputs.items()})
                 total_batched_samples += 1
                 if rng_to_sync:
                     self._load_rng_state(resume_from_checkpoint)
@@ -1792,6 +1817,10 @@ def _inner_training_loop(
                 if step % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
 
+                if step == profile_step and epoch == profile_epoch:
+                    trace = lambda: xp.trace('127.0.0.1:9012', profile_logdir or tempfile.mkdtemp(), profile_duration or 20000)
+                    Thread(target=trace).start()
+
                 with self.accelerator.accumulate(model):
                     tr_loss_step = self.training_step(model, inputs)
 
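As a usage note (not part of the patch): `xp.trace` connects to a profiler server, assumed here to be listening on port 9012; torch_xla programs typically start one with `xp.start_server`. A hedged sketch of the launcher-side setup, using the `PROFILE_*` environment variables read in the hunk above:

# Hypothetical launcher-side setup, assuming torch_xla's profiler API
# (xp.start_server) and the PROFILE_* variables read by the training loop above.
import os
import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # serve profiling data on the port xp.trace() targets

os.environ.setdefault('PROFILE_EPOCH', '0')            # epoch at which to capture the trace
os.environ.setdefault('PROFILE_STEP', '10')            # step within that epoch
os.environ.setdefault('PROFILE_DURATION_MS', '20000')  # trace length in milliseconds
os.environ.setdefault('PROFILE_LOGDIR', '/tmp/xla_profile')  # TensorBoard log directory

# trainer.train() would then call xp.trace('127.0.0.1:9012', ...) in a background
# thread once the configured step and epoch are reached.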
@@ -2199,7 +2228,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
             self.log(logs)
 
         metrics = None
-        if self.control.should_evaluate:
+        # TODO(jonbolin): Disabling eval loop
+        if False:  # self.control.should_evaluate:
             if isinstance(self.eval_dataset, dict):
                 metrics = {}
                 for eval_dataset_name, eval_dataset in self.eval_dataset.items():
@@ -2914,7 +2944,7 @@ def evaluate(
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
 
-        eval_dataloader = self.get_eval_dataloader(eval_dataset)
+        eval_dataloader = self._xla_sharded_dataloader(self.get_eval_dataloader(eval_dataset))
         start_time = time.time()
 
         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop