examples/ner/run_ner.py: 58 changes (15 additions, 43 deletions)
@@ -31,28 +31,15 @@
 from tqdm import tqdm, trange
 
 from transformers import (
+    ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
     WEIGHTS_NAME,
     AdamW,
-    AlbertConfig,
-    AlbertForTokenClassification,
-    AlbertTokenizer,
-    BertConfig,
-    BertForTokenClassification,
-    BertTokenizer,
-    CamembertConfig,
-    CamembertForTokenClassification,
-    CamembertTokenizer,
-    DistilBertConfig,
-    DistilBertForTokenClassification,
-    DistilBertTokenizer,
-    RobertaConfig,
-    RobertaForTokenClassification,
-    RobertaTokenizer,
-    XLMRobertaConfig,
-    XLMRobertaForTokenClassification,
-    XLMRobertaTokenizer,
+    AutoConfig,
+    AutoModelForTokenClassification,
+    AutoTokenizer,
     get_linear_schedule_with_warmup,
 )
+from transformers.modeling_auto import MODEL_MAPPING
 from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
 
 
@@ -64,22 +51,8 @@
 
 logger = logging.getLogger(__name__)
 
-ALL_MODELS = sum(
-    (
-        tuple(conf.pretrained_config_archive_map.keys())
-        for conf in (BertConfig, RobertaConfig, DistilBertConfig, CamembertConfig, XLMRobertaConfig)
-    ),
-    (),
-)
-
-MODEL_CLASSES = {
-    "albert": (AlbertConfig, AlbertForTokenClassification, AlbertTokenizer),
-    "bert": (BertConfig, BertForTokenClassification, BertTokenizer),
-    "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
-    "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
-    "camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
-    "xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer),
-}
+ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
+MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING)
 
 TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]

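Note on the two rewritten constants above: they rely on the transformers internals this PR imports, where ALL_PRETRAINED_MODEL_ARCHIVE_MAP is a dict keyed by checkpoint shortcut names and MODEL_MAPPING maps config classes to model classes, so iterating either one yields its keys. A minimal sketch of what the one-liners produce, under that assumption:

```python
from transformers import ALL_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.modeling_auto import MODEL_MAPPING

# Iterating a dict yields its keys, so this is every known shortcut name.
ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP)  # ("bert-base-uncased", ...)

# MODEL_MAPPING's keys are config classes; each carries a `model_type` string,
# which is what the --model_type argument is matched against below.
for config_class in MODEL_MAPPING:
    print(config_class.model_type)  # "bert", "roberta", "distilbert", ...
```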
@@ -411,7 +384,7 @@ def main():
         default=None,
         type=str,
         required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
     )
     parser.add_argument(
         "--model_name_or_path",
@@ -594,8 +567,7 @@ def main():
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
 
     args.model_type = args.model_type.lower()
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(
+    config = AutoConfig.from_pretrained(
         args.config_name if args.config_name else args.model_name_or_path,
         num_labels=num_labels,
         id2label={str(i): label for i, label in enumerate(labels)},
@@ -604,12 +576,12 @@
     )
     tokenizer_args = {k: v for k, v in vars(args).items() if v is not None and k in TOKENIZER_ARGS}
     logger.info("Tokenizer arguments: %s", tokenizer_args)
-    tokenizer = tokenizer_class.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
         cache_dir=args.cache_dir if args.cache_dir else None,
         **tokenizer_args,
     )
-    model = model_class.from_pretrained(
+    model = AutoModelForTokenClassification.from_pretrained(
         args.model_name_or_path,
         from_tf=bool(".ckpt" in args.model_name_or_path),
         config=config,
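The three Auto calls above are what make the per-architecture MODEL_CLASSES lookup unnecessary: each Auto class reads the model type recorded in the config (or resolved from the checkpoint name) and instantiates the matching concrete class. A standalone sketch of that dispatch; the checkpoint name and num_labels are illustrative, not taken from this PR:

```python
from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer

# Illustrative checkpoint and label count (e.g. CoNLL-2003 BIO tagging).
config = AutoConfig.from_pretrained("bert-base-cased", num_labels=9)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForTokenClassification.from_pretrained("bert-base-cased", config=config)

print(type(config).__name__, type(model).__name__)
# -> BertConfig BertForTokenClassification, resolved from the checkpoint
#    without any explicit {model_type: (config, model, tokenizer)} table.
```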
@@ -650,7 +622,7 @@ def main():
     # Evaluation
     results = {}
     if args.do_eval and args.local_rank in [-1, 0]:
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir, **tokenizer_args)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
             checkpoints = list(
@@ -660,7 +632,7 @@ def main():
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
-            model = model_class.from_pretrained(checkpoint)
+            model = AutoModelForTokenClassification.from_pretrained(checkpoint)
             model.to(args.device)
             result, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev", prefix=global_step)
             if global_step:
@@ -672,8 +644,8 @@ def main():
                 writer.write("{} = {}\n".format(key, str(results[key])))
 
     if args.do_predict and args.local_rank in [-1, 0]:
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
-        model = model_class.from_pretrained(args.output_dir)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir, **tokenizer_args)
+        model = AutoModelForTokenClassification.from_pretrained(args.output_dir)
         model.to(args.device)
         result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")
         # Save results
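The eval and predict branches can reload through the same Auto classes because save_pretrained writes the config next to the weights, so from_pretrained on the output directory re-dispatches to whatever architecture was trained. A sketch of that round trip; out_dir is an illustrative path:

```python
out_dir = "/tmp/ner-model"  # illustrative path, not from the PR

# Training side: persist weights, config, and tokenizer files together.
model.save_pretrained(out_dir)
tokenizer.save_pretrained(out_dir)

# Eval/predict side: the saved config.json tells the Auto classes which
# concrete model and tokenizer classes to instantiate, so the script no
# longer needs the model_class/tokenizer_class variables it used to carry.
model = AutoModelForTokenClassification.from_pretrained(out_dir)
tokenizer = AutoTokenizer.from_pretrained(out_dir)
```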