Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions examples/ner/run_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,14 @@
from tqdm import tqdm, trange

from transformers import (
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
WEIGHTS_NAME,
AdamW,
AutoConfig,
AutoModelForTokenClassification,
AutoTokenizer,
get_linear_schedule_with_warmup,
)
from transformers.modeling_auto import MODEL_MAPPING
from transformers.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file


Expand All @@ -51,8 +50,9 @@

logger = logging.getLogger(__name__)

ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), ())

TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]

Expand Down Expand Up @@ -384,7 +384,7 @@ def main():
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
)
parser.add_argument(
"--model_name_or_path",
Expand Down