Merged
79 changes: 51 additions & 28 deletions examples/run_glue_tpu.py
@@ -135,14 +135,21 @@ def train(args, train_dataset, model, tokenizer, disable_logging=False):
train_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
epoch_iterator = tqdm(train_dataloader, desc="Iteration", total=len(dataloader), disable=disable_logging)
for step, batch in enumerate(epoch_iterator):

# Save model checkpoint.
if args.save_steps > 0 and global_step % args.save_steps == 0:
output_dir = os.path.join(args.output_dir, 'checkpoint-{}-xla{}'.format(global_step, xm.get_ordinal()))
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
logger.info("Saving model checkpoint to %s", output_dir)
if not os.path.exists(output_dir):
os.makedirs(output_dir)

if xm.is_master_ordinal():
if not os.path.exists(output_dir):
os.makedirs(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))

# Barrier to wait for saving checkpoint.
xm.rendezvous('mid_training_checkpoint')
# model.save_pretrained needs to be called by all ordinals
model.save_pretrained(output_dir, xla_device=True)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))

model.train()
inputs = {'input_ids': batch[0],
@@ -263,15 +270,19 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):


def load_and_cache_examples(args, task, tokenizer, evaluate=False):
if not xm.is_master_ordinal():
xm.rendezvous('load_and_cache_examples')

processor = processors[task]()
output_mode = output_modes[task]
cached_features_file = os.path.join(
args.data_dir, 'cached_{}_{}_{}_{}'.format(
'dev' if evaluate else 'train',
list(filter(None, args.model_name_or_path.split('/'))).pop(),
str(args.max_seq_length),
str(task)))

# Load data features from cache or dataset file
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}_xla{}'.format(
'dev' if evaluate else 'train',
list(filter(None, args.model_name_or_path.split('/'))).pop(),
str(args.max_seq_length),
str(task),
xm.get_ordinal()))
if os.path.exists(cached_features_file) and not args.overwrite_cache:
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
@@ -294,6 +305,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)

if xm.is_master_ordinal():
xm.rendezvous('load_and_cache_examples')

# Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
@@ -341,6 +355,9 @@ def main(args):
label_list = processor.get_labels()
num_labels = len(label_list)

if not xm.is_master_ordinal():
xm.rendezvous('download_only_once') # Make sure only the first process in distributed training will download model & vocab

Review comment: I'm assuming no tensor work is done in the methods between the two download_only_once rendezvous calls.

Author: Yep, correct.
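
For context on the thread above: the pattern relies on two rendezvous calls with the same tag, so non-master ordinals wait while the master downloads, and then read from the shared cache. A minimal sketch, using only the xm.is_master_ordinal() and xm.rendezvous() calls already present in this diff; the helper name and the model_name argument are illustrative, not part of the change:

import torch_xla.core.xla_model as xm

def load_pretrained_once_per_host(model_class, model_name):
    # Non-master ordinals block here until the master reaches the same
    # rendezvous tag, so only one process downloads and populates the cache.
    if not xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    # The master downloads; the other ordinals later read from the cache.
    model = model_class.from_pretrained(model_name)

    # The master arrives here after downloading and releases the waiting ordinals.
    if xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    return model

As the thread confirms, nothing between the two rendezvous calls should do tensor work on the XLA device; the barrier only serializes the download.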


# Load pretrained model and tokenizer
args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
@@ -359,47 +376,53 @@ def main(args):
cache_dir=args.cache_dir if args.cache_dir else None,
xla_device=True)

if xm.is_master_ordinal():
xm.rendezvous('download_only_once')

# Send model to TPU/XLA device.
model.to(args.device)

logger.info("Training/evaluation parameters %s", args)

output_dir = os.path.join(args.output_dir, 'final-xla{}'.format(xm.get_ordinal()))
if args.do_train:
# Train the model.
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
global_step, tr_loss = train(args, train_dataset, model, tokenizer, disable_logging=disable_logging)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

# Save trained model.
# Saving best-practices: if you use default names for the model, you can reload it using from_pretrained()

# Create output directory if needed
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if xm.is_master_ordinal():
# Save trained model.
# Saving best-practices: if you use default names for the model, you can reload it using from_pretrained()

# Create output directory if needed
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)

logger.info("Saving model checkpoint to %s", output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model.save_pretrained(output_dir, xla_device=True)
tokenizer.save_pretrained(output_dir)
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
tokenizer.save_pretrained(args.output_dir)
# Good practice: save your training arguments together with the trained model.
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

# Good practice: save your training arguments together with the trained model.
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
xm.rendezvous('post_training_checkpoint')
# model.save_pretrained needs to be called by all ordinals
model.save_pretrained(args.output_dir, xla_device=True)

# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(output_dir, xla_device=True)
tokenizer = tokenizer_class.from_pretrained(output_dir, xla_device=True)
model = model_class.from_pretrained(args.output_dir, xla_device=True)
tokenizer = tokenizer_class.from_pretrained(args.output_dir, xla_device=True)
model.to(args.device)


# Evaluation
results = {}
if args.do_eval:
tokenizer = tokenizer_class.from_pretrained(output_dir, do_lower_case=args.do_lower_case, xla_device=True)
checkpoints = [output_dir]
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case, xla_device=True)
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
4 changes: 1 addition & 3 deletions transformers/configuration_utils.py
@@ -122,7 +122,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
force_download = kwargs.pop('force_download', False)
proxies = kwargs.pop('proxies', None)
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
xla_device = kwargs.pop('xla_device', False)

if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
@@ -132,8 +131,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
config_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, xla_device=xla_device)
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
15 changes: 5 additions & 10 deletions transformers/file_utils.py
@@ -102,7 +102,7 @@ def docstring_decorator(fn):
return fn
return docstring_decorator

def url_to_filename(url, etag=None, xla_device=False):
def url_to_filename(url, etag=None):
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
@@ -120,10 +120,6 @@ def url_to_filename(url, etag=None, xla_device=False):
etag_hash = sha256(etag_bytes)
filename += '.' + etag_hash.hexdigest()

if xla_device:
import torch_xla.core.xla_model as xm
filename += '.xla' + str(xm.get_ordinal())

if url.endswith('.h5'):
filename += '.h5'

@@ -156,7 +152,7 @@ def filename_to_url(filename, cache_dir=None):
return url, etag


def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, xla_device=False):
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
@@ -177,8 +173,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N

if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, xla_device=xla_device)
return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
@@ -252,7 +247,7 @@ def http_get(url, temp_file, proxies=None):


def get_from_cache(url, cache_dir=None, force_download=False, proxies=None,
etag_timeout=10, xla_device=False):
etag_timeout=10):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
@@ -282,7 +277,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None,

if sys.version_info[0] == 2 and etag is not None:
etag = etag.decode('utf-8')
filename = url_to_filename(url, etag, xla_device=xla_device)
filename = url_to_filename(url, etag)

# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
16 changes: 9 additions & 7 deletions transformers/modeling_utils.py
@@ -245,17 +245,20 @@ def save_pretrained(self, save_directory, xla_device=False):
# Only save the model itself if we are using distributed training
model_to_save = self.module if hasattr(self, 'module') else self

# Save configuration file
model_to_save.config.save_pretrained(save_directory)

# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)

# Save configuration file
if xla_device:
# Saving for each process is safe since output_dir is proc# namespaced.
import torch_xla.core.xla_model as xm
xm.save(model_to_save.state_dict(), output_model_file, master_only=False)
if xm.is_master_ordinal():
model_to_save.config.save_pretrained(save_directory)
# xm.save takes care of saving only from master
xm.save(model_to_save.state_dict(), output_model_file)

Review comment: Idea for the future: you could use the checkpoint tagger later, if they mark which checkpoint is the best, etc.

Author: As far as I know, they don't have such a tagger, but let me see if they do. Good idea, we could even upstream it if they don't. Thanks.

else:
model_to_save.config.save_pretrained(save_directory)
torch.save(model_to_save.state_dict(), output_model_file)

logger.info("Model weights saved in {}".format(output_model_file))

@classmethod
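
To summarize the save path that this hunk converges on for XLA: the config is written by the master ordinal only, while every ordinal calls xm.save, which by default writes from the master. Below is a minimal sketch, assuming xm.save's default master_only=True behavior as the comment in the hunk indicates; the helper name and the literal 'pytorch_model.bin' (standing in for WEIGHTS_NAME) are illustrative:

import os
import torch_xla.core.xla_model as xm

def save_xla_checkpoint(model, save_directory):
    # Only the master ordinal writes the config file, avoiding concurrent
    # writes to the same shared directory.
    if xm.is_master_ordinal():
        model.config.save_pretrained(save_directory)

    # xm.save moves the state dict to CPU and, by default, writes it from
    # the master ordinal only; every ordinal still calls it, matching the
    # "needs to be called by all ordinals" note earlier in this diff.
    output_model_file = os.path.join(save_directory, 'pytorch_model.bin')
    xm.save(model.state_dict(), output_model_file)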
@@ -371,8 +374,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, xla_device=xla_device)
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
5 changes: 2 additions & 3 deletions transformers/tokenization_utils.py
@@ -287,7 +287,6 @@ def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs)
cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False)
proxies = kwargs.pop('proxies', None)
xla_device = kwargs.pop('xla_device', False)

s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {}
@@ -343,7 +342,7 @@ def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs)
"We assumed '{}' was a path or url to a directory containing vocabulary files "
"named {} but couldn't find such vocabulary files at this path or url.".format(
pretrained_model_name_or_path, ', '.join(s3_models),
pretrained_model_name_or_path,
pretrained_model_name_or_path,
list(cls.vocab_files_names.values())))

# Get files from url, cache, or disk depending on the case
@@ -354,7 +353,7 @@ def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs)
resolved_vocab_files[file_id] = None
else:
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir,
force_download=force_download, proxies=proxies, xla_device=xla_device)
force_download=force_download, proxies=proxies)
except EnvironmentError:
if pretrained_model_name_or_path in s3_models:
msg = "Couldn't reach server at '{}' to download vocabulary files."