Merged
30 changes: 29 additions & 1 deletion src/transformers/generation_tf_utils.py
@@ -579,6 +579,7 @@ def generate(
         do_sample = do_sample if do_sample is not None else self.config.do_sample

         if do_sample is False or num_beams == 1:
+            seed = model_kwargs.pop("seed", None)
             return self._generate(
                 input_ids=input_ids,
                 max_length=max_length,
@@ -601,13 +602,14 @@
                 attention_mask=attention_mask,
                 decoder_start_token_id=decoder_start_token_id,
                 use_cache=use_cache,
-                seed=model_kwargs.pop("seed", None),
+                seed=seed,
                 output_scores=output_scores,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 forced_bos_token_id=forced_bos_token_id,
                 forced_eos_token_id=forced_eos_token_id,
+                **model_kwargs,
             )

     # We cannot generate if the model does not have a LM head
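
The refactor above pops `seed` from `model_kwargs` once, up front, so it is consumed on the sampling path and never forwarded to the model as a stray kwarg (which would now fail the validation added below). A minimal usage sketch, assuming the TF `seed` argument takes the two-integer form used by `tf.random` stateless ops:

from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = tokenizer("Hello", return_tensors="tf").input_ids

# Same stateless seed twice -> identical sampled sequences.
out_a = model.generate(input_ids, do_sample=True, seed=[42, 0], max_length=8)
out_b = model.generate(input_ids, do_sample=True, seed=[42, 0], max_length=8)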
@@ -1288,6 +1290,29 @@ def adjust_logits_during_generation(
         else:
             return logits

Comment (Contributor Author): Same as for PyTorch (here), with `self.forward` replaced with `self.call`.

+    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
+        # Excludes arguments that are handled before calling any model function
+        if self.config.is_encoder_decoder:
+            for key in ["decoder_input_ids"]:
+                model_kwargs.pop(key, None)
+
+        unused_model_args = []
+        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
+        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
+        if "kwargs" in model_args:
+            model_args |= set(inspect.signature(self.call).parameters)
+        for key, value in model_kwargs.items():
+            if value is not None and key not in model_args:
+                unused_model_args.append(key)
+
+        if unused_model_args:
+            raise ValueError(
+                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
+                " generate arguments will also show up in this list)"
+            )
+
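
For intuition, here is a minimal standalone sketch of the reflection-based check above (the toy `call` signature below is invented for illustration):

import inspect

# Toy stand-in for a model's `call` / `prepare_inputs_for_generation` signature.
def call(input_ids=None, attention_mask=None, **kwargs):
    pass

model_args = set(inspect.signature(call).parameters)
# model_args == {'input_ids', 'attention_mask', 'kwargs'}

# Anything the caller passed that is not in this set gets reported:
model_kwargs = {"attention_mask": [[1, 1]], "do_samples": True}
unused = [k for k, v in model_kwargs.items() if v is not None and k not in model_args]
print(unused)  # ['do_samples']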

     def _generate(
         self,
         input_ids=None,
@@ -1483,6 +1508,9 @@ def _generate(
         # generate sequences without allowing bad_words to be generated
         outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
         ```"""
+        # 0. Validate model kwargs
+        self._validate_model_kwargs(model_kwargs.copy())
+

Comment (Contributor): Haha, fine with me! Can also increase all numbers otherwise.

         # 1. Set generation parameters if not already defined
         length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
         early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
183 changes: 183 additions & 0 deletions tests/generation/test_generation_tf_utils.py
@@ -0,0 +1,183 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

from transformers import is_tf_available
from transformers.testing_utils import require_tf, slow


if is_tf_available():
    import tensorflow as tf

    from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering


@require_tf
class UtilsFunctionsTest(unittest.TestCase):

    # tests whether the top_k_top_p_filtering function behaves as expected
    def test_top_k_top_p_filtering(self):

Comment (Contributor Author): Moved from test_modeling_tf_common, no changes.

        logits = tf.convert_to_tensor(
            [
                [
                    8.2220991,  # 3rd highest value; idx. 0
                    -0.5620044,
                    5.23229752,
                    4.0386393,
                    -6.8798378,
                    -0.54785802,
                    -3.2012153,
                    2.92777176,
                    1.88171953,
                    7.35341276,  # 5th highest value; idx. 9
                    8.43207833,  # 2nd highest value; idx. 10
                    -9.85711836,
                    -5.96209236,
                    -1.13039161,
                    -7.1115294,
                    -0.8369633,
                    -5.3186408,
                    7.06427407,
                    0.81369344,
                    -0.82023817,
                    -5.9179796,
                    0.58813443,
                    -6.99778438,
                    4.71551189,
                    -0.18771637,
                    7.44020759,  # 4th highest value; idx. 25
                    9.38450987,  # 1st highest value; idx. 26
                    2.12662941,
                    -9.32562038,
                    2.35652522,
                ],  # cumulative prob of 5 highest values <= 0.6
                [
                    0.58425518,
                    4.53139238,
                    -5.57510464,
                    -6.28030699,
                    -7.19529503,
                    -4.02122551,
                    1.39337037,
                    -6.06707057,
                    1.59480517,
                    -9.643119,
                    0.03907799,
                    0.67231762,
                    -8.88206726,
                    6.27115922,  # 4th highest value; idx. 13
                    2.28520723,
                    4.82767506,
                    4.30421368,
                    8.8275313,  # 2nd highest value; idx. 17
                    5.44029958,  # 5th highest value; idx. 18
                    -4.4735794,
                    7.38579536,  # 3rd highest value; idx. 20
                    -2.91051663,
                    2.61946077,
                    -2.5674762,
                    -9.48959302,
                    -4.02922645,
                    -1.35416918,
                    9.67702323,  # 1st highest value; idx. 27
                    -5.89478553,
                    1.85370467,
                ],  # cumulative prob of 5 highest values <= 0.6
            ],
            dtype=tf.float32,
        )

        non_inf_expected_idx = tf.convert_to_tensor(
            [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
            dtype=tf.int32,
        )  # expected non filtered idx as noted above

        non_inf_expected_output = tf.convert_to_tensor(
            [8.222099, 7.3534126, 8.432078, 7.4402075, 9.38451, 6.271159, 8.827531, 5.4402995, 7.3857956, 9.677023],
            dtype=tf.float32,
        )  # expected non filtered values as noted above

        output = tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)

        non_inf_output = output[output != -float("inf")]
        non_inf_idx = tf.cast(
            tf.where(tf.not_equal(output, tf.constant(-float("inf"), dtype=tf.float32))),
            dtype=tf.int32,
        )

        tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
        tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
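
For readers new to the function under test: a simplified NumPy sketch of combined top-k/top-p filtering, illustrating the algorithm rather than reproducing `tf_top_k_top_p_filtering` itself:

import numpy as np

def top_k_top_p_sketch(logits, top_k=0, top_p=1.0, min_tokens_to_keep=1, filter_value=-np.inf):
    logits = np.array(logits, dtype=np.float32)
    if top_k > 0:
        # Top-k: keep only the `max(top_k, min_tokens_to_keep)` largest logits per row.
        k = max(top_k, min_tokens_to_keep)
        kth_largest = np.sort(logits, axis=-1)[:, -k, None]
        logits = np.where(logits < kth_largest, filter_value, logits)
    if top_p < 1.0:
        # Top-p: drop the tail once cumulative probability (in descending order) exceeds `top_p`.
        order = np.argsort(-logits, axis=-1)
        sorted_logits = np.take_along_axis(logits, order, axis=-1)
        probs = np.exp(sorted_logits - sorted_logits.max(axis=-1, keepdims=True))
        cumulative = np.cumsum(probs / probs.sum(axis=-1, keepdims=True), axis=-1)
        remove = cumulative > top_p
        remove[:, : min_tokens_to_keep - 1] = False  # enforce the floor (no-op when the floor is 1)
        remove[:, 1:] = remove[:, :-1].copy()        # shift right so the threshold-crossing token survives
        remove[:, 0] = False
        sorted_logits[remove] = filter_value
        np.put_along_axis(logits, order, sorted_logits, axis=-1)  # undo the sort
    return logits

print(top_k_top_p_sketch([[1.0, 2.0, 3.0, 4.0]], top_k=3, top_p=0.9))
# [[-inf -inf 3. 4.]] -- 4 and 3 already exceed the 0.9 probability mass, so everything else is filtered.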


@require_tf
class TFGenerationIntegrationTests(unittest.TestCase):
    @slow
    def test_generate_tf_function_export(self):

Comment (Contributor Author): Moved from test_modeling_tf_common; added the @slow (it takes >30s).

Comment (Contributor): Cool, that makes sense (similar to PyTorch). Also, just FYI: in PyTorch we currently test much more than in TF, mainly because we allow returning hidden_states and attentions. We could do the same for TF at some point.

        test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        max_length = 2

        class DummyModel(tf.Module):
            def __init__(self, model):
                super(DummyModel, self).__init__()
                self.model = model

            @tf.function(
                input_signature=(
                    tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
                    tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
                ),
                jit_compile=True,
            )
            def serving(self, input_ids, attention_mask):
                outputs = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_length,
                    return_dict_in_generate=True,
                )
                return {"sequences": outputs["sequences"]}

        dummy_input_ids = [[2, 0], [102, 103]]
        dummy_attention_masks = [[1, 0], [1, 1]]
        dummy_model = DummyModel(model=test_model)
        with tempfile.TemporaryDirectory() as tmp_dir:
            tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
            serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
            for batch_size in range(1, len(dummy_input_ids) + 1):
                inputs = {
                    "input_ids": tf.constant(dummy_input_ids[:batch_size]),
                    "attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
                }
                tf_func_outputs = serving_func(**inputs)["sequences"]
                tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
                tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
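
Once exported, the loaded signature can also be inspected directly; a small sketch (run inside the `with` block above, while `tmp_dir` still exists):

loaded = tf.saved_model.load(tmp_dir)
serving_fn = loaded.signatures["serving_default"]
print(serving_fn.structured_input_signature)  # ((), {'input_ids': TensorSpec(...), 'attention_mask': TensorSpec(...)})
print(serving_fn.structured_outputs)          # {'sequences': TensorSpec(...)}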

    def test_validate_generation_inputs(self):
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
        model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")

        encoder_input_str = "Hello world"
        input_ids = tokenizer(encoder_input_str, return_tensors="tf").input_ids

        # typos are quickly detected (the correct argument is `do_sample`)
        with self.assertRaisesRegex(ValueError, "do_samples"):
            model.generate(input_ids, do_samples=True)

        # arbitrary arguments that will not be used anywhere are also not accepted
        with self.assertRaisesRegex(ValueError, "foo"):
            fake_model_kwargs = {"foo": "bar"}
            model.generate(input_ids, **fake_model_kwargs)
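
For reference, the failure these assertions expect looks like this (the message format comes from `_validate_model_kwargs` in the first file):

try:
    model.generate(input_ids, do_samples=True)
except ValueError as e:
    print(e)
# The following `model_kwargs` are not used by the model: ['do_samples'] (note: typos in the
# generate arguments will also show up in this list)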
4 changes: 2 additions & 2 deletions tests/generation/test_generation_utils.py
@@ -2704,8 +2704,8 @@ def test_constrained_beam_search_mixin_type_checks(self):
             model.generate(input_ids, force_words_ids=[[[-1]]])

     def test_validate_generation_inputs(self):
-        tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
-        model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+        model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")

Comment (Contributor Author): The model is not relevant for the test, but not using a model from hf-internal-testing was an oversight in the previous PR :D

Comment (Contributor): Yes, def a good idea!

         encoder_input_str = "Hello world"
         input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids