diff --git a/examples/ner/run_ner.py b/examples/ner/run_ner.py index e4bc01a45e52..54decd6e02ba 100644 --- a/examples/ner/run_ner.py +++ b/examples/ner/run_ner.py @@ -31,7 +31,6 @@ from tqdm import tqdm, trange from transformers import ( - ALL_PRETRAINED_MODEL_ARCHIVE_MAP, WEIGHTS_NAME, AdamW, AutoConfig, @@ -39,7 +38,7 @@ AutoTokenizer, get_linear_schedule_with_warmup, ) -from transformers.modeling_auto import MODEL_MAPPING +from transformers.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file @@ -51,8 +50,9 @@ logger = logging.getLogger(__name__) -ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP) -MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING) +MODEL_CONFIG_CLASSES = list(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) +ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), ()) TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"] @@ -384,7 +384,7 @@ def main(): default=None, type=str, required=True, - help="Model type selected in the list: " + ", ".join(MODEL_CLASSES), + help="Model type selected in the list: " + ", ".join(MODEL_TYPES), ) parser.add_argument( "--model_name_or_path",