13 changes: 10 additions & 3 deletions src/transformers/configuration_utils.py
@@ -78,9 +78,6 @@ def __init__(self, **kwargs):
self.top_k = kwargs.pop("top_k", 50)
self.top_p = kwargs.pop("top_p", 1.0)
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
@@ -94,6 +91,16 @@ def __init__(self, **kwargs):
self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())

# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
self.prefix = kwargs.pop("prefix", None)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

# task specific arguments
self.task_specific_params = kwargs.pop("task_specific_params", None)
Reviewer comment (Member):

Should it be kwargs.pop("task_specific_params", {}) so we can directly test its keys?

@patrickvonplaten (Contributor, Author), Mar 25, 2020:
Guess it's the same pro/con discussion as the one above about prefix = kwargs.pop("prefix", ""). If the default is not None, it would at the moment be serialized and saved in all configs that don't have task_specific_params, or am I missing something?
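
For illustration, here is a minimal sketch of the trade-off under discussion. ToyConfig and its None-dropping serialization are assumptions made for this example only, not the actual PretrainedConfig behaviour:

```python
import json


class ToyConfig:
    """Toy stand-in for a model config; not the real PretrainedConfig."""

    def __init__(self, **kwargs):
        self.model_type = kwargs.pop("model_type", "toy")
        # The choice under discussion: None vs. {} as the default.
        self.task_specific_params = kwargs.pop("task_specific_params", None)

    def to_json_string(self):
        # Assumed simplification: attributes set to None are dropped before saving,
        # so a None default stays invisible in configs that never use the field,
        # while a {} default would end up written into every saved config.
        data = {key: value for key, value in self.__dict__.items() if value is not None}
        return json.dumps(data, sort_keys=True)


print(ToyConfig().to_json_string())
# {"model_type": "toy"}  -- no task_specific_params entry with a None default
print(ToyConfig(task_specific_params={"summarization": {"max_length": 142}}).to_json_string())
# {"model_type": "toy", "task_specific_params": {"summarization": {"max_length": 142}}}
```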


# Additional attributes without default values
for key, value in kwargs.items():
try:
13 changes: 8 additions & 5 deletions src/transformers/modeling_tf_utils.py
@@ -610,7 +610,9 @@ def generate(
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)

if input_ids is not None:
batch_size = shape_list(input_ids)[0] # overriden by the input batch_size
@@ -635,9 +637,6 @@
assert (eos_token_id is None) or (
isinstance(eos_token_id, int) and (eos_token_id >= 0)
), "`eos_token_id` should be a positive integer."
assert (
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
assert length_penalty > 0, "`length_penalty` should be strictely positive."
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
@@ -708,8 +707,12 @@ def generate(
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)

if self.config.is_encoder_decoder:
if decoder_start_token_id is None:
decoder_start_token_id = bos_token_id

assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
assert (
decoder_start_token_id is not None
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

14 changes: 9 additions & 5 deletions src/transformers/modeling_utils.py
@@ -809,7 +809,9 @@ def generate(
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)

if input_ids is not None:
batch_size = input_ids.shape[0] # overriden by the input batch_size
@@ -831,9 +833,6 @@
assert pad_token_id is None or (
isinstance(pad_token_id, int) and (pad_token_id >= 0)
), "`pad_token_id` should be a positive integer."
assert (
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
assert (eos_token_id is None) or (
isinstance(eos_token_id, int) and (eos_token_id >= 0)
), "`eos_token_id` should be a positive integer."
@@ -912,7 +911,12 @@ def generate(
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)

if self.config.is_encoder_decoder:
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
if decoder_start_token_id is None:
decoder_start_token_id = bos_token_id

assert (
decoder_start_token_id is not None
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

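Taken together, the generate() changes in both backends make decoder_start_token_id resolve in three steps: an explicitly passed value, then config.decoder_start_token_id, then (for encoder-decoder models only) bos_token_id, with an assertion if all three are missing. A standalone sketch of that resolution order follows; resolve_decoder_start_token_id is a hypothetical helper written for this note, not a function in the library:

```python
from typing import Optional


def resolve_decoder_start_token_id(
    passed: Optional[int],
    config_decoder_start_token_id: Optional[int],
    bos_token_id: Optional[int],
    is_encoder_decoder: bool,
) -> Optional[int]:
    """Simplified mirror of the fallback chain introduced in this diff."""
    # 1. An explicitly passed value always wins; otherwise use the config value.
    token_id = passed if passed is not None else config_decoder_start_token_id
    if is_encoder_decoder:
        # 2. Encoder-decoder models may still fall back to bos_token_id ...
        if token_id is None:
            token_id = bos_token_id
        # 3. ... but one of the two must ultimately be defined.
        assert token_id is not None, (
            "decoder_start_token_id or bos_token_id has to be defined "
            "for encoder-decoder generation"
        )
    return token_id


# Nothing passed and nothing in the config: an encoder-decoder model falls back to BOS.
print(resolve_decoder_start_token_id(None, None, bos_token_id=0, is_encoder_decoder=True))  # 0
```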