5 changes: 3 additions & 2 deletions src/transformers/generation_beam_search.py
@@ -332,7 +332,8 @@ def finalize(
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

# prepare for adding eos
- sent_max_len = min(sent_lengths.max().item() + 1, max_length)
+ sent_lengths_max = sent_lengths.max().item() + 1
+ sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
# shorter batches are padded if needed
if sent_lengths.min().item() != sent_lengths.max().item():
@@ -341,7 +342,7 @@
# fill with hypotheses and eos_token_id if the latter fits in
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
- if sent_lengths[i] < max_length:
+ if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id

return UserDict(
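For context on these two hunks: `min()` cannot compare an `int` with `None`, so the old one-liner raised a `TypeError` whenever `finalize` ran with `max_length=None`; the eos check below it changes for the same reason, comparing against the computed `sent_max_len` instead of the possibly-`None` `max_length`. A minimal standalone sketch of the guard, using an illustrative value in place of the real tensor math:

# Sketch of the guard introduced above (illustrative value, not the real
# tensors from generation_beam_search.py).
sent_lengths_max = 7  # stands in for sent_lengths.max().item() + 1

for max_length in (10, 5, None):
    # the old code did min(sent_lengths_max, max_length), which raises
    # TypeError when max_length is None
    sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
    print(f"max_length={max_length!r} -> sent_max_len={sent_max_len}")
# prints 7, 5, and 7 respectively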
42 changes: 42 additions & 0 deletions tests/generation/test_generation_utils.py
@@ -2310,6 +2310,48 @@ def test_transition_scores_group_beam_search_encoder_decoder(self):

self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))

@slow
def test_beam_search_example_integration(self):
# the exact example from the beam search docstring, which previously
# failed when copied and run directly. See PR #15555
tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

# let's run beam search using 3 beams
num_beams = 3
# define decoder start token ids
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id

# add encoder_outputs to model keyword arguments
model_kwargs = {
"encoder_outputs": model.get_encoder()(
encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
)
}

# instantiate beam scorer
beam_scorer = BeamSearchScorer(
batch_size=1,
num_beams=num_beams,
device=model.device,
)

# instantiate logits processors
logits_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
]
)

outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

self.assertListEqual(outputs, ["Wie alt bist du?"])

@slow
def test_constrained_beam_search(self):
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
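Both tests are decorated with @slow, so a default test run skips them; under the usual transformers convention they should be runnable with something like:

RUN_SLOW=1 python -m pytest tests/generation/test_generation_utils.py -k "beam_search_example_integration or constrained_beam_search"

(RUN_SLOW=1 is the standard switch for @slow tests; the exact invocation may differ by setup.)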