Description
The question-answering pipeline crashes for RoBERTa models.
It loads the model and tokenizer correctly, but the SQuAD preprocessing produces a wrong `p_mask`, leading to no possible prediction and the error message below.
The observed `p_mask` for a RoBERTa model is

```
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]
```

while it should only mask the question tokens, like this:

```
[0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, ...]
```
I think the deeper root cause here is that RoBERTa's `token_type_ids` returned from `encode_plus` are now all zeros (introduced in #2432), while the creation of `p_mask` in `squad_convert_example_to_features` relies on this information:
From transformers/src/transformers/data/processors/squad.py, lines 189 to 202 at 520e7f2:

```python
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
# Original TF implem also keep the classification token (set to 0) (not sure why...)
p_mask = np.array(span["token_type_ids"])

p_mask = np.minimum(p_mask, 1)

if tokenizer.padding_side == "right":
    # Limit positive values to one
    p_mask = 1 - p_mask

p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1

# Set the CLS index to '0'
p_mask[cls_index] = 0
```
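To make the failure concrete, here is a minimal sketch (my own illustration, not library code) of what that snippet computes when `token_type_ids` comes back all zeros, as it now does for RoBERTa:

```python
import numpy as np

# For roberta, encode_plus now returns all-zero token_type_ids (see #2432).
token_type_ids = np.zeros(12, dtype=int)

p_mask = np.minimum(token_type_ids, 1)  # still all zeros
p_mask = 1 - p_mask                     # padding_side == "right" -> all ones

print(p_mask)  # [1 1 1 1 1 1 1 1 1 1 1 1]
# After p_mask[cls_index] = 0, only the CLS token remains a valid candidate,
# so the pipeline can never place a start/end inside the context.
```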
I haven't checked yet, but this might also affect training/eval if `p_mask` is used there.
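One way to check would be to inspect the `p_mask` on the generated features directly; a sketch, reusing the question/context strings from the reproduction below:

```python
from transformers import AutoTokenizer
from transformers.data.processors.squad import SquadExample, squad_convert_examples_to_features

tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
example = SquadExample(
    qas_id="0",
    question_text="What is roberta?",
    context_text="Roberta is a language model that was trained for a longer time, on more data, without NSP",
    answer_text=None,
    start_position_character=None,
    title=None,
)
features = squad_convert_examples_to_features(
    examples=[example],
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
)
print(features[0].p_mask)  # for roberta: 1 everywhere except CLS, i.e. context masked out
```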
How to reproduce?
```python
from transformers import pipeline

model_name = "deepset/roberta-base-squad2"
nlp = pipeline('question-answering', model=model_name, tokenizer=model_name)
res = nlp({
    'question': 'What is roberta?',
    'context': 'Roberta is a language model that was trained for a longer time, on more data, without NSP'
})
```
results in:

```
  File "/home/mp/deepset/dev/transformers/src/transformers/pipelines.py", line 847, in __call__
    for s, e, score in zip(starts, ends, scores)
  File "/home/mp/deepset/dev/transformers/src/transformers/pipelines.py", line 847, in <listcomp>
    for s, e, score in zip(starts, ends, scores)
KeyError: 0
```
Environment
- Ubuntu 18.04
- Python 3.7.6
- PyTorch 1.3.1