diff --git a/examples/ner/run_ner.py b/examples/ner/run_ner.py
index c32b3af226e9..e4bc01a45e52 100644
--- a/examples/ner/run_ner.py
+++ b/examples/ner/run_ner.py
@@ -31,28 +31,15 @@
 from tqdm import tqdm, trange
 
 from transformers import (
+    ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
     WEIGHTS_NAME,
     AdamW,
-    AlbertConfig,
-    AlbertForTokenClassification,
-    AlbertTokenizer,
-    BertConfig,
-    BertForTokenClassification,
-    BertTokenizer,
-    CamembertConfig,
-    CamembertForTokenClassification,
-    CamembertTokenizer,
-    DistilBertConfig,
-    DistilBertForTokenClassification,
-    DistilBertTokenizer,
-    RobertaConfig,
-    RobertaForTokenClassification,
-    RobertaTokenizer,
-    XLMRobertaConfig,
-    XLMRobertaForTokenClassification,
-    XLMRobertaTokenizer,
+    AutoConfig,
+    AutoModelForTokenClassification,
+    AutoTokenizer,
     get_linear_schedule_with_warmup,
 )
+from transformers.modeling_auto import MODEL_MAPPING
 from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
 
 
@@ -64,22 +51,8 @@
 
 logger = logging.getLogger(__name__)
 
-ALL_MODELS = sum(
-    (
-        tuple(conf.pretrained_config_archive_map.keys())
-        for conf in (BertConfig, RobertaConfig, DistilBertConfig, CamembertConfig, XLMRobertaConfig)
-    ),
-    (),
-)
-
-MODEL_CLASSES = {
-    "albert": (AlbertConfig, AlbertForTokenClassification, AlbertTokenizer),
-    "bert": (BertConfig, BertForTokenClassification, BertTokenizer),
-    "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
-    "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
-    "camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
-    "xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer),
-}
+ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
+MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING)
 
 TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
 
@@ -411,7 +384,7 @@ def main():
         default=None,
         type=str,
         required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
     )
     parser.add_argument(
         "--model_name_or_path",
@@ -594,8 +567,7 @@ def main():
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
 
     args.model_type = args.model_type.lower()
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(
+    config = AutoConfig.from_pretrained(
         args.config_name if args.config_name else args.model_name_or_path,
         num_labels=num_labels,
         id2label={str(i): label for i, label in enumerate(labels)},
@@ -604,12 +576,12 @@ def main():
     )
     tokenizer_args = {k: v for k, v in vars(args).items() if v is not None and k in TOKENIZER_ARGS}
     logger.info("Tokenizer arguments: %s", tokenizer_args)
-    tokenizer = tokenizer_class.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
         cache_dir=args.cache_dir if args.cache_dir else None,
         **tokenizer_args,
     )
-    model = model_class.from_pretrained(
+    model = AutoModelForTokenClassification.from_pretrained(
         args.model_name_or_path,
         from_tf=bool(".ckpt" in args.model_name_or_path),
         config=config,
@@ -650,7 +622,7 @@ def main():
     # Evaluation
     results = {}
     if args.do_eval and args.local_rank in [-1, 0]:
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir, **tokenizer_args)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
             checkpoints = list(
@@ -660,7 +632,7 @@ def main():
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
-            model = model_class.from_pretrained(checkpoint)
+            model = AutoModelForTokenClassification.from_pretrained(checkpoint)
             model.to(args.device)
             result, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev", prefix=global_step)
             if global_step:
@@ -672,8 +644,8 @@ def main():
                 writer.write("{} = {}\n".format(key, str(results[key])))
 
     if args.do_predict and args.local_rank in [-1, 0]:
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
-        model = model_class.from_pretrained(args.output_dir)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir, **tokenizer_args)
+        model = AutoModelForTokenClassification.from_pretrained(args.output_dir)
         model.to(args.device)
         result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")
         # Save results
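
For context, the pattern this patch adopts is to let the Auto classes dispatch on the checkpoint's config instead of hard-coding one (config, model, tokenizer) triple per architecture. Below is a minimal, self-contained sketch of the same three calls run_ner.py now makes; the checkpoint name "bert-base-cased" and the label list are illustrative placeholders, not part of the patch.

# Sketch of the Auto-class pattern used above (placeholder checkpoint and labels).
from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer

labels = ["O", "B-PER", "I-PER"]  # example NER tag set

# Build a token-classification config with the label mappings, as in main().
config = AutoConfig.from_pretrained(
    "bert-base-cased",
    num_labels=len(labels),
    id2label={str(i): label for i, label in enumerate(labels)},
    label2id={label: i for i, label in enumerate(labels)},
)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForTokenClassification.from_pretrained("bert-base-cased", config=config)

Because the architecture is inferred from the checkpoint, the same code path also reloads fine-tuned checkpoints from args.output_dir during evaluation and prediction, which is why the per-model MODEL_CLASSES dictionary could be dropped.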