
Conversation

@LysandreJik (Member) commented Mar 3, 2020

Currently, encode_plus and batch_encode_plus return the same outputs for different models.

This is sub-optimal as we can't do the following for each model:

inputs = tokenizer.encode_plus(sequence, return_tensors="pt")
model(**inputs)

This will crash for DistilBERT, as the tokenizer would return token_type_ids, which the model cannot handle.
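As a rough sketch of the manual workaround this currently forces (assuming the behaviour described above; the pop call below is the manual fix, not part of this proposal):

from transformers import DistilBertModel, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
model = DistilBertModel.from_pretrained("distilbert-base-cased")

inputs = tokenizer.encode_plus("Hey, how are you?", return_tensors="pt")

# encode_plus currently returns token_type_ids, which DistilBertModel's
# forward() does not accept, so the user has to drop it by hand.
inputs.pop("token_type_ids", None)

outputs = model(**inputs)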

In order to fix this, each tokenizer has to return model-specific arguments. Most models share the same default arguments, while some handle fewer (e.g. DistilBERT, RoBERTa).

This is a mock PR offering a solution using a return_outputs argument (initially named skip_outputs) on tokenizers.

from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
print(tokenizer.encode_plus("Hey, how are you?"))

Returns a dictionary without the token type ids:

{'input_ids': [101, 4403, 117, 1293, 1132, 1128, 136, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}

Specifying a custom return_outputs at initialisation works as expected:

from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased", return_outputs=["attention_mask", "token_type_ids"])
print(tokenizer.encode_plus("Hey, how are you?"))
{'input_ids': [101, 4403, 117, 1293, 1132, 1128, 136, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}

or keeping only a custom subset of outputs:

from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased", return_outputs=["token_type_ids"])
print(tokenizer.encode_plus("Hey, how are you?"))
{'input_ids': [101, 4403, 117, 1293, 1132, 1128, 136, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0]}

This also works with saving/reloading:

from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased", return_outputs=["token_type_ids"])
print(tokenizer.encode_plus("Hey, how are you?"))
tokenizer.save_pretrained("xxx")

tokenizer = DistilBertTokenizer.from_pretrained("xxx")
print(tokenizer.encode_plus("Hey, how are you?"))

Returns the following:

{'input_ids': [101, 4403, 117, 1293, 1132, 1128, 136, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0]}
{'input_ids': [101, 4403, 117, 1293, 1132, 1128, 136, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0]}

@LysandreJik LysandreJik requested review from julien-c and thomwolf March 3, 2020 23:28
@thomwolf (Member) commented Mar 4, 2020

Nice.

One question: do we want a skip_output flag or a keep_output flag?

skip_output seems to me to introduce a dependency to be maintained between all the models (if we later add a model with an additional output processed by encode_plus, we would have to update all the other models to avoid this output).

keep_output is longer to write right now (we have to add it for all the models), but once it's added, all the models are independent from each other.
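To make the trade-off concrete, here is a small illustrative sketch of how the two flags would filter the dictionary produced by encode_plus (skip_output and keep_output are just the options under discussion, not an existing API):

# Illustrative only: filtering encode_plus's output dict.
encoded = {
    "input_ids": [101, 102],
    "token_type_ids": [0, 0],
    "attention_mask": [1, 1],
}

# Blacklist (skip_output): each model lists the outputs it does NOT accept,
# so any output added later for another model forces updates everywhere.
skip_output = ["token_type_ids"]
blacklist_filtered = {k: v for k, v in encoded.items() if k not in skip_output}

# Whitelist (keep_output): each model lists exactly what it accepts and is
# unaffected by outputs introduced for other models.
keep_output = ["input_ids", "attention_mask"]
whitelist_filtered = {k: v for k, v in encoded.items() if k in keep_output}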

@julien-c (Member) commented Mar 4, 2020

I'm ok with both solutions (by the way, in general terms, a lot of software can accept a combination of whitelist and/or blacklist. When both are present, it's usually "include the whitelist, and remove the blacklist")

If we do keep_output, maybe we name the attribute return_outputs: List[str] for consistency with encode_xxx() params?

@LysandreJik (Member Author) commented Mar 4, 2020

I agree with both of you. Furthermore, this approach (deleting keys from the dict that encode_plus generated) is not compatible with the return_xxx arguments of encode_plus.

I'm implementing both of your proposed changes, and looking into fixing the above as well as fast tokenizers.

I'll then move on to the tests.

  • replace the blacklist by a whitelist
  • rename to return_outputs for consistency with encode_plus arguments
  • compatibility with all of encode_plus's arguments
  • fast tokenizers
  • tests

return_attention_mask=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False,
return_token_type_ids=None,
Member

maybe explicitly define them as: Optional[bool] = None?

Member Author

Yes, I'll update the typings, doc & tests if @mfuntowicz and @thomwolf agree that this is the best way to do it.

Member

Why do we move from True/False to None/XXX?

Member Author

Using None as the default means we can safely detect whether the user passed a value different from the default. If a value is None, we fall back to the return_outputs attribute to decide whether to return that output.
If it is not None, its value is absolute (as it was explicitly passed by the user).
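As a rough sketch of that resolution logic (a simplified, hypothetical helper, not the actual encode_plus code):

def resolve_return_flag(user_value, output_name, return_outputs):
    # None means the user did not set the flag, so fall back to the whitelist;
    # an explicit True/False passed by the user always wins.
    if user_value is None:
        return output_name in return_outputs
    return user_value

# With DistilBERT's default whitelist of ["attention_mask"]:
resolve_return_flag(None, "token_type_ids", ["attention_mask"])  # False
resolve_return_flag(True, "token_type_ids", ["attention_mask"])  # True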

Member

Ok I see, then I guess a clear typing like @julien-c suggested would be good

@mfuntowicz (Member) commented Mar 5, 2020

I like the solution 👍.

One question: it requires the user to know / look at the names of the parameters handled by __call__() / forward(). Should we expose a property on PreTrainedModel giving the list of parameters supported by the model? This one would be overridden in RoBERTa and DistilBERT.

model = SomeModel(...)
tokenizer = AutoTokenizer.from_pretrained(..., return_outputs=model.input_names)
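As a sketch of what such a property could look like (purely illustrative; the name input_names and the defaults below are assumptions, not code from this PR):

class PreTrainedModel:
    # Hypothetical default: most models accept all three tokenizer outputs.
    input_names = ["input_ids", "token_type_ids", "attention_mask"]


class DistilBertModel(PreTrainedModel):
    # DistilBERT's forward() takes no token_type_ids, so it narrows the list.
    input_names = ["input_ids", "attention_mask"]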

@LysandreJik (Member Author) commented Mar 5, 2020

Indeed, such an attribute would be helpful! I'll add it and move on to the tests.

@thomwolf (Member) left a comment

ok did a review but I'm not sure it was finished enough to do one haha
anyway, here are some comments


pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
return_outputs = ["attention_mask"]
Member

rename this to model_inputs or model_input_names?

Member Author

I'll let @julien-c answer that, as he offered that naming:

> If we do keep_output, maybe we name the attribute return_outputs: List[str] for consistency with encode_xxx() params?

if return_attention_mask is None:
return_attention_mask = "attention_mask" in self.return_outputs
if return_overflowing_tokens is None:
return_overflowing_tokens = "overflowing_tokens" in self.return_outputs
Member

"overflowing_tokens" will never be in self.return_outputs, won't it?

Member Author

Not by default, but a user could pass it as an argument so that it is always returned. Not as important as the other three though, I agree. Will remove it.

@thomwolf (Member) commented Mar 5, 2020

Regarding the suggestion of @mfuntowicz: in the end, this should probably live in a common configuration for models and tokenizers, so maybe we could already have this attribute as an input to PretrainedTokenizer.__init__() (instead of a class attribute) to prepare for the future.

@LysandreJik (Member Author)

That value is currently managed by the __init__ method; see the examples above.

It still needs to be a class attribute in my opinion, as it should be overridden by children of PreTrainedTokenizer and it should be known by encode_plus/encode/batch_encode_plus.
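As a rough sketch of that arrangement (simplified, hypothetical class names, not the PR's actual code): the default lives on the base class, subclasses such as DistilBERT's tokenizer narrow it, and __init__ still lets a user override it per instance.

class BaseTokenizerSketch:
    return_outputs = ["token_type_ids", "attention_mask"]

    def __init__(self, return_outputs=None):
        # An explicit value passed at initialisation overrides the class default.
        if return_outputs is not None:
            self.return_outputs = return_outputs


class DistilBertTokenizerSketch(BaseTokenizerSketch):
    # DistilBERT models do not take token_type_ids.
    return_outputs = ["attention_mask"]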

  return [
      {value: batch_encode_plus_sequences[value][i] for value in batch_encode_plus_sequences.keys()}
-     for i in range(len(batch_encode_plus_sequences))
+     for i in range(len(batch_encode_plus_sequences["input_ids"]))
Member Author

bugfix

@LysandreJik (Member Author)

Should be good for review. I reverted the docs commit because it made the review harder. I'll recommit the docs at merge time.

@LysandreJik LysandreJik marked this pull request as ready for review March 6, 2020 17:48
@codecov-io commented Mar 6, 2020

Codecov Report

Merging #3116 into master will increase coverage by <.01%.
The diff coverage is 100%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #3116      +/-   ##
==========================================
+ Coverage   77.98%   77.99%   +<.01%     
==========================================
  Files          98       98              
  Lines       16645    16660      +15     
==========================================
+ Hits        12981    12994      +13     
- Misses       3664     3666       +2
Impacted Files | Coverage Δ
--- | ---
src/transformers/tokenization_utils.py | 91.85% <100%> (+0.12%) ⬆️
src/transformers/tokenization_roberta.py | 100% <100%> (ø) ⬆️
src/transformers/tokenization_distilbert.py | 100% <100%> (ø) ⬆️
src/transformers/file_utils.py | 68% <0%> (-0.41%) ⬇️
src/transformers/modeling_utils.py | 94.4% <0%> (-0.16%) ⬇️

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 49debe6...96b2fa1.

@LysandreJik LysandreJik changed the title [MOCK] Skipping outputs - Minimal example Skipping outputs - Minimal example Mar 9, 2020
@LysandreJik LysandreJik changed the title Skipping outputs - Minimal example Skipping outputs Mar 9, 2020
@LysandreJik LysandreJik merged commit 5164ea9 into master Mar 9, 2020
@LysandreJik LysandreJik deleted the skip-outputs branch March 9, 2020 17:49
@LysandreJik (Member Author) commented Mar 9, 2020

Merged after offline review from @thomwolf and @julien-c

jplu pushed a commit to jplu/transformers that referenced this pull request Mar 25, 2020
* Minimal example

* Proposal 2

* Proposal 2 for fast tokenizers

* Typings

* Docs

* Revert "Docs" for easier review

This reverts commit eaf0f97.

* Remove unnecessary assignments

* Tests

* Fix faulty type

* Remove prints

* return_outputs -> model_input_names

* Revert "Revert "Docs" for easier review"

This reverts commit 6fdc694.

* code quality
@LysandreJik LysandreJik mentioned this pull request Mar 25, 2020