Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions examples/run_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
Expand Down Expand Up @@ -155,8 +156,8 @@ def _create_examples(self, lines, set_type):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
text_a = line[8])
text_b = line[9])
text_a = line[8]
text_b = line[9]
label = line[-1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
Expand Down Expand Up @@ -482,7 +483,7 @@ def main():
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

# Prepare model
model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list),
model = BertForSequenceClassification.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
if args.fp16:
model.half()
Expand All @@ -507,10 +508,13 @@ def main():
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
]
t_total = num_train_steps
if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_steps)
t_total=t_total)

global_step = 0
if args.do_train:
Expand Down Expand Up @@ -571,7 +575,7 @@ def main():
model.zero_grad()
global_step += 1

if args.do_eval:
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer)
Expand All @@ -583,10 +587,8 @@ def main():
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
eval_sampler = SequentialSampler(eval_data)
else:
eval_sampler = DistributedSampler(eval_data)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

model.eval()
Expand Down
44 changes: 31 additions & 13 deletions examples/run_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import math
import os
import random
import pickle
from tqdm import tqdm, trange

import numpy as np
Expand All @@ -35,6 +36,7 @@
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
Expand Down Expand Up @@ -749,6 +751,10 @@ def main():
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--do_lower_case",
default=True,
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--local_rank",
type=int,
default=-1,
Expand Down Expand Up @@ -845,20 +851,34 @@ def main():
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
]
t_total = num_train_steps
if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_steps)
t_total=t_total)

global_step = 0
if args.do_train:
train_features = convert_examples_to_features(
examples=train_examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=True)
cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
args.bert_model, str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
train_features = None
try:
with open(cached_train_features_file, "rb") as reader:
train_features = pickle.load(reader)
except:
train_features = convert_examples_to_features(
examples=train_examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=True)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving train features into cached file %s", cached_train_features_file)
with open(cached_train_features_file, "wb") as writer:
train_features = pickle.dump(train_features, writer)
logger.info("***** Running training *****")
logger.info(" Num orig examples = %d", len(train_examples))
logger.info(" Num split examples = %d", len(train_features))
Expand Down Expand Up @@ -913,7 +933,7 @@ def main():
model.zero_grad()
global_step += 1

if args.do_predict:
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = read_squad_examples(
input_file=args.predict_file, is_training=False)
eval_features = convert_examples_to_features(
Expand All @@ -934,10 +954,8 @@ def main():
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
if args.local_rank == -1:
eval_sampler = SequentialSampler(eval_data)
else:
eval_sampler = DistributedSampler(eval_data)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)

model.eval()
Expand Down