Generate: validate model_kwargs on TF (and catch typos in generate arguments) #18651
Changes to the TF generation utilities:

```diff
@@ -579,6 +579,7 @@ def generate(
         do_sample = do_sample if do_sample is not None else self.config.do_sample
 
         if do_sample is False or num_beams == 1:
+            seed = model_kwargs.pop("seed", None)
             return self._generate(
                 input_ids=input_ids,
                 max_length=max_length,
@@ -601,13 +602,14 @@ def generate(
                 attention_mask=attention_mask,
                 decoder_start_token_id=decoder_start_token_id,
                 use_cache=use_cache,
-                seed=model_kwargs.pop("seed", None),
+                seed=seed,
                 output_scores=output_scores,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 forced_bos_token_id=forced_bos_token_id,
                 forced_eos_token_id=forced_eos_token_id,
+                **model_kwargs,
             )
 
         # We cannot generate if the model does not have a LM head
```
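The `seed` is popped from `model_kwargs` up front so that the later `**model_kwargs` expansion (and the validation added below) never sees it. For context, a hedged usage sketch, with `model` and `input_ids` assumed to exist from elsewhere; TF's stateless sampling ops consume a pair of integers:

```python
# Assumes `model` is a TF generation-capable model and `input_ids` a tf.Tensor.
output = model.generate(input_ids, do_sample=True, seed=[42, 0], max_new_tokens=10)
```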
```diff
@@ -1288,6 +1290,29 @@ def adjust_logits_during_generation(
         else:
             return logits
 
+    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
+        # Excludes arguments that are handled before calling any model function
+        if self.config.is_encoder_decoder:
+            for key in ["decoder_input_ids"]:
+                model_kwargs.pop(key, None)
+
+        unused_model_args = []
+        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
+        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
+        if "kwargs" in model_args:
+            model_args |= set(inspect.signature(self.call).parameters)
+        for key, value in model_kwargs.items():
+            if value is not None and key not in model_args:
+                unused_model_args.append(key)
+
+        if unused_model_args:
+            raise ValueError(
+                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
+                " generate arguments will also show up in this list)"
+            )
+
     def _generate(
         self,
         input_ids=None,
```
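As a standalone illustration of the reflection trick (the toy class and parameter names below are hypothetical, not from the PR): `inspect.signature` yields the parameter names a function accepts, and a catch-all `**kwargs` in `prepare_inputs_for_generation` hides the real forward-pass inputs, which is why the check is widened with the signature of `call`:

```python
import inspect


class Toy:
    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        ...

    def call(self, input_ids, past=None, attention_mask=None, token_type_ids=None):
        ...


toy = Toy()
model_args = set(inspect.signature(toy.prepare_inputs_for_generation).parameters)
print(model_args)  # contains 'input_ids', 'past', 'kwargs'
# The catch-all `**kwargs` masks the true inputs, so also inspect `call`
# (the TF forward pass) before flagging anything as unused:
if "kwargs" in model_args:
    model_args |= set(inspect.signature(toy.call).parameters)
print("token_type_ids" in model_args)  # True -> would not be flagged as unused
```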
```diff
@@ -1483,6 +1508,9 @@ def _generate(
         # generate sequences without allowing bad_words to be generated
         outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
         ```"""
+        # 0. Validate model kwargs
+        self._validate_model_kwargs(model_kwargs.copy())
+
         # 1. Set generation parameters if not already defined
         length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
         early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
```
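`_generate` passes a copy because `_validate_model_kwargs` pops keys from the dict it receives. The net effect, mirroring the new tests below, is that a typo now fails fast instead of being silently ignored (sketch; `model` and `input_ids` assumed from context):

```python
model.generate(input_ids, do_samples=True)  # typo for `do_sample`
# ValueError: The following `model_kwargs` are not used by the model: ['do_samples']
# (note: typos in the generate arguments will also show up in this list)
```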
The PR also adds a new file of TF generation tests (`@@ -0,0 +1,183 @@`):

```python
# coding=utf-8
# Copyright 2022 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

from transformers import is_tf_available
from transformers.testing_utils import require_tf, slow


if is_tf_available():
    import tensorflow as tf

    from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering


@require_tf
class UtilsFunctionsTest(unittest.TestCase):

    # tests whether the top_k_top_p_filtering function behaves as expected
    def test_top_k_top_p_filtering(self):
        logits = tf.convert_to_tensor(
            [
                [
                    8.2220991,  # 3rd highest value; idx. 0
                    -0.5620044,
                    5.23229752,
                    4.0386393,
                    -6.8798378,
                    -0.54785802,
                    -3.2012153,
                    2.92777176,
                    1.88171953,
                    7.35341276,  # 5th highest value; idx. 9
                    8.43207833,  # 2nd highest value; idx. 10
                    -9.85711836,
                    -5.96209236,
                    -1.13039161,
                    -7.1115294,
                    -0.8369633,
                    -5.3186408,
                    7.06427407,
                    0.81369344,
                    -0.82023817,
                    -5.9179796,
                    0.58813443,
                    -6.99778438,
                    4.71551189,
                    -0.18771637,
                    7.44020759,  # 4th highest value; idx. 25
                    9.38450987,  # 1st highest value; idx. 26
                    2.12662941,
                    -9.32562038,
                    2.35652522,
                ],  # cumulative prob of 5 highest values <= 0.6
                [
                    0.58425518,
                    4.53139238,
                    -5.57510464,
                    -6.28030699,
                    -7.19529503,
                    -4.02122551,
                    1.39337037,
                    -6.06707057,
                    1.59480517,
                    -9.643119,
                    0.03907799,
                    0.67231762,
                    -8.88206726,
                    6.27115922,  # 4th highest value; idx. 13
                    2.28520723,
                    4.82767506,
                    4.30421368,
                    8.8275313,  # 2nd highest value; idx. 17
                    5.44029958,  # 5th highest value; idx. 18
                    -4.4735794,
                    7.38579536,  # 3rd highest value; idx. 20
                    -2.91051663,
                    2.61946077,
                    -2.5674762,
                    -9.48959302,
                    -4.02922645,
                    -1.35416918,
                    9.67702323,  # 1st highest value; idx. 27
                    -5.89478553,
                    1.85370467,
                ],  # cumulative prob of 5 highest values <= 0.6
            ],
            dtype=tf.float32,
        )

        non_inf_expected_idx = tf.convert_to_tensor(
            [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
            dtype=tf.int32,
        )  # expected non filtered idx as noted above

        non_inf_expected_output = tf.convert_to_tensor(
            [8.222099, 7.3534126, 8.432078, 7.4402075, 9.38451, 6.271159, 8.827531, 5.4402995, 7.3857956, 9.677023],
            dtype=tf.float32,
        )  # expected non filtered values as noted above

        output = tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)

        non_inf_output = output[output != -float("inf")]
        non_inf_idx = tf.cast(
            tf.where(tf.not_equal(output, tf.constant(-float("inf"), dtype=tf.float32))),
            dtype=tf.int32,
        )

        tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
        tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)


@require_tf
class TFGenerationIntegrationTests(unittest.TestCase):
    @slow
    def test_generate_tf_function_export(self):
        test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        max_length = 2

        class DummyModel(tf.Module):
            def __init__(self, model):
                super(DummyModel, self).__init__()
                self.model = model

            @tf.function(
                input_signature=(
                    tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
                    tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
                ),
                jit_compile=True,
            )
            def serving(self, input_ids, attention_mask):
                outputs = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_length,
                    return_dict_in_generate=True,
                )
                return {"sequences": outputs["sequences"]}

        dummy_input_ids = [[2, 0], [102, 103]]
        dummy_attention_masks = [[1, 0], [1, 1]]
        dummy_model = DummyModel(model=test_model)
        with tempfile.TemporaryDirectory() as tmp_dir:
            tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
            serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
            for batch_size in range(1, len(dummy_input_ids) + 1):
                inputs = {
                    "input_ids": tf.constant(dummy_input_ids[:batch_size]),
                    "attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
                }
                tf_func_outputs = serving_func(**inputs)["sequences"]
                tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
                tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)

    def test_validate_generation_inputs(self):
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
        model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")

        encoder_input_str = "Hello world"
        input_ids = tokenizer(encoder_input_str, return_tensors="tf").input_ids

        # typos are quickly detected (the correct argument is `do_sample`)
        with self.assertRaisesRegex(ValueError, "do_samples"):
            model.generate(input_ids, do_samples=True)

        # arbitrary arguments that will not be used anywhere are also not accepted
        with self.assertRaisesRegex(ValueError, "foo"):
            fake_model_kwargs = {"foo": "bar"}
            model.generate(input_ids, **fake_model_kwargs)
```
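As intuition for what `test_top_k_top_p_filtering` asserts, here is a minimal sketch on a made-up 4-token distribution: logits filtered out by the top-k or top-p rule are set to `-inf`, so sampling can never pick them.

```python
import tensorflow as tf

from transformers import tf_top_k_top_p_filtering

logits = tf.constant([[1.0, 2.0, 3.0, 4.0]])
# Keep at most the 2 highest logits; top_p=1.0 disables the nucleus cut here.
filtered = tf_top_k_top_p_filtering(logits, top_k=2, top_p=1.0)
print(filtered.numpy())  # [[-inf -inf   3.   4.]]
```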
Finally, the existing PyTorch test is updated to use the `hf-internal-testing` checkpoints:

```diff
@@ -2704,8 +2704,8 @@ def test_constrained_beam_search_mixin_type_checks(self):
         model.generate(input_ids, force_words_ids=[[[-1]]])
 
     def test_validate_generation_inputs(self):
-        tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
-        model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+        model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
 
         encoder_input_str = "Hello world"
         input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
```
Review comment: Same as for PyTorch (here), with `self.forward` replaced with `self.call`.
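That is, the only delta from the PyTorch version of `_validate_model_kwargs` is the forward-pass entry point inspected; a paraphrased fragment (not a verbatim quote of either file):

```python
# PyTorch mixin (pre-existing):
model_args |= set(inspect.signature(self.forward).parameters)
# TF mixin (this PR):
model_args |= set(inspect.signature(self.call).parameters)
```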