Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(self, **kwargs):
# Parameters for sequence generation
self.max_length = kwargs.pop("max_length", 20)
self.do_sample = kwargs.pop("do_sample", False)
self.early_stopping = kwargs.pop("early_stopping", False)
self.num_beams = kwargs.pop("num_beams", 1)
self.temperature = kwargs.pop("temperature", 1.0)
self.top_k = kwargs.pop("top_k", 50)
Expand Down
17 changes: 12 additions & 5 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ def generate(
input_ids=None,
max_length=None,
do_sample=True,
early_stopping=False,
num_beams=None,
temperature=None,
top_k=None,
Expand Down Expand Up @@ -559,11 +560,12 @@ def generate(
if self.get_output_embeddings() is None:
raise AttributeError(
"You tried to generate sequences with a model that does not have a LM Head."
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`)"
"Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5WithLMHeadModel`, `TFTransfoXLLMHeadModel`)"
)

max_length = max_length if max_length is not None else self.config.max_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature
top_k = top_k if top_k is not None else self.config.top_k
Expand All @@ -586,6 +588,7 @@ def generate(

assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
assert temperature > 0, "`temperature` should be strictely positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
Expand Down Expand Up @@ -662,6 +665,7 @@ def generate(
cur_len,
max_length,
do_sample,
early_stopping,
temperature,
top_k,
top_p,
Expand Down Expand Up @@ -803,6 +807,7 @@ def _generate_beam_search(
cur_len,
max_length,
do_sample,
early_stopping,
temperature,
top_k,
top_p,
Expand All @@ -820,7 +825,8 @@ def _generate_beam_search(

# generated hypotheses
generated_hyps = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
for _ in range(batch_size)
]

# scores for each sentence in the beam
Expand Down Expand Up @@ -1058,10 +1064,11 @@ def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
for i, prev_input_id in enumerate(prev_input_ids):
logit_penalized = logits[i].numpy()[prev_input_id]
logit_penalties = np.zeros(logit_penalized.shape)
# if previous logit score is < 0 then multiply repetition penalty else divide
logit_penalized[logit_penalized < 0] = repetition_penalty
logit_penalized[logit_penalized > 0] = 1 / repetition_penalty
np.put(token_penalties[i], prev_input_id, logit_penalized)
logit_penalties[logit_penalized < 0] = repetition_penalty
logit_penalties[logit_penalized > 0] = 1 / repetition_penalty
np.put(token_penalties[i], prev_input_id, logit_penalties)
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)


Expand Down