From c621e6671aae88c0695b735a07f792639c57c8f7 Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Sun, 8 Oct 2023 21:30:49 +0300
Subject: [PATCH 01/29] add early stopping logits processor

---
 src/transformers/generation/logits_process.py | 23 +++++++++++++++++++
 .../bark/generation_configuration_bark.py     |  4 ++++
 src/transformers/models/bark/modeling_bark.py |  5 ++--
 3 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 14f772ab6c99..f02507d5f08d 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -1702,3 +1702,26 @@ def __call__(self, input_ids, scores):
         unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
         out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
         return out
+
+
+class BarkEarlyStoppingLogitsProcessor(LogitsProcessor):
+    r"""This processor will set every tokens' log probability other than the EOS token to `-inf`
+    when the probabiliy of the EOS token id is superior to min_eos_p.
+
+    Args:
+        generate_config (`GenerateConfig`):
+            The generate config used to generate the output. The following parameters are required:
+            eos_token_id (`int`, *optional*, defaults to 50257):
+                The id of the *end-of-sequence* token.
+            min_eos_p (`float`, *optional*, defaults to None):
+                Minimum end of speech threshold.
+    """
+
+    def __init__(self, generate_config):
+        self.eos_token_id = generate_config.eos_token_id
+        self.min_eos_p = generate_config.min_eos_p
+
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
+        scores[:, self.eos_token_id] = -float("inf")
+        return scores
diff --git a/src/transformers/models/bark/generation_configuration_bark.py b/src/transformers/models/bark/generation_configuration_bark.py
index 92d836333935..e2df243d2d97 100644
--- a/src/transformers/models/bark/generation_configuration_bark.py
+++ b/src/transformers/models/bark/generation_configuration_bark.py
@@ -44,6 +44,7 @@ def __init__(
         semantic_vocab_size=10_000,
         max_input_semantic_length=256,
         semantic_rate_hz=49.9,
+        min_eos_p=None,
         **kwargs,
     ):
         """Class that holds a generation configuration for [`BarkSemanticModel`].
@@ -86,6 +87,8 @@ def __init__(
                 Max length of semantic input vector.
             semantic_rate_hz (`float`, *optional*, defaults to 49.9):
                 Semantic rate in Hertz.
+            min_eos_p (`float`, *optional*, defaults to None):
+                Minimum end of speech threshold.
         """
         super().__init__(
             temperature=temperature,
@@ -107,6 +110,7 @@ def __init__(
         self.semantic_vocab_size = semantic_vocab_size
         self.max_input_semantic_length = max_input_semantic_length
         self.semantic_rate_hz = semantic_rate_hz
+        self.min_eos_p = min_eos_p


 class BarkCoarseGenerationConfig(GenerationConfig):
diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index bdafb6347755..68a2f6412392 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -21,7 +21,7 @@
 from torch import nn
 from torch.nn import functional as F

-from ...generation.logits_process import AlternatingCodebooksLogitsProcessor, SuppressTokensLogitsProcessor
+from ...generation.logits_process import AlternatingCodebooksLogitsProcessor, BarkEarlyStoppingLogitsProcessor, SuppressTokensLogitsProcessor
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel, get_parameter_device
 from ...utils import (
@@ -793,13 +793,14 @@ def generate(
         )
         suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress)
+        early_stopping_logits_processor = BarkEarlyStoppingLogitsProcessor(semantic_generation_config)

         # pass input_ids in order to stay consistent with the transformers generate method even though it is not used
         # (except to get the input seq_len - that's why we keep the first 257 tokens)
         semantic_output = super().generate(
             torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int).to(self.device),
             input_embeds=input_embeds,
-            logits_processor=[suppress_tokens_logits_processor],
+            logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor],
             generation_config=semantic_generation_config,
             **kwargs,
         )  # size: 10048
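The mechanism this first patch builds on is generic: `generate` calls each supplied `LogitsProcessor` once per decoding step with the running `input_ids` and the next-token `scores`. A minimal standalone sketch of that contract — the class name and token id below are toy assumptions for illustration, not part of the patch:

import torch
from transformers import LogitsProcessor, LogitsProcessorList

class BanTokenProcessor(LogitsProcessor):
    """Toy processor: forbid one token id by setting its score to -inf."""
    def __init__(self, token_id: int):
        self.token_id = token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores[:, self.token_id] = -float("inf")
        return scores

# generate() merges user-supplied processors with its built-in defaults:
# model.generate(input_ids, logits_processor=LogitsProcessorList([BanTokenProcessor(13)]))

Note that `__call__` takes both `input_ids` and `scores`; patch 04 below brings the new processor in line with exactly this signature.
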
""" super().__init__( temperature=temperature, @@ -107,6 +110,7 @@ def __init__( self.semantic_vocab_size = semantic_vocab_size self.max_input_semantic_length = max_input_semantic_length self.semantic_rate_hz = semantic_rate_hz + self.min_eos_p = min_eos_p class BarkCoarseGenerationConfig(GenerationConfig): diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index bdafb6347755..68a2f6412392 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -21,7 +21,7 @@ from torch import nn from torch.nn import functional as F -from ...generation.logits_process import AlternatingCodebooksLogitsProcessor, SuppressTokensLogitsProcessor +from ...generation.logits_process import AlternatingCodebooksLogitsProcessor, BarkEarlyStoppingLogitsProcessor, SuppressTokensLogitsProcessor from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput from ...modeling_utils import PreTrainedModel, get_parameter_device from ...utils import ( @@ -793,13 +793,14 @@ def generate( ) suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress) + early_stopping_logits_processor = BarkEarlyStoppingLogitsProcessor(semantic_generation_config) # pass input_ids in order to stay consistent with the transformers generate method even though it is not used # (except to get the input seq_len - that's why we keep the first 257 tokens) semantic_output = super().generate( torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int).to(self.device), input_embeds=input_embeds, - logits_processor=[suppress_tokens_logits_processor], + logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor], generation_config=semantic_generation_config, **kwargs, ) # size: 10048 From 9027782f9ad6ce11de6570b7a0fdbbfd12f7c638 Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Sun, 8 Oct 2023 21:58:30 +0300 Subject: [PATCH 02/29] black formmated --- src/transformers/generation/logits_process.py | 4 ++-- src/transformers/models/bark/modeling_bark.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index f02507d5f08d..a213a8bc6203 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1702,12 +1702,12 @@ def __call__(self, input_ids, scores): unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1) out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits return out - + class BarkEarlyStoppingLogitsProcessor(LogitsProcessor): r"""This processor will set every tokens' log probability other than the EOS token to `-inf` when the probabiliy of the EOS token id is superior to min_eos_p. - + Args: generate_config (`GenerateConfig`): The generate config used to generate the output. 
The following parameters are required: diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 68a2f6412392..1f27e41186ef 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -21,7 +21,11 @@ from torch import nn from torch.nn import functional as F -from ...generation.logits_process import AlternatingCodebooksLogitsProcessor, BarkEarlyStoppingLogitsProcessor, SuppressTokensLogitsProcessor +from ...generation.logits_process import ( + AlternatingCodebooksLogitsProcessor, + BarkEarlyStoppingLogitsProcessor, + SuppressTokensLogitsProcessor, +) from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput from ...modeling_utils import PreTrainedModel, get_parameter_device from ...utils import ( From a93b03cbb3971450cb5daf2d8c9e1c2e39868107 Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Sun, 8 Oct 2023 23:32:20 +0300 Subject: [PATCH 03/29] indent --- src/transformers/generation/logits_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index a213a8bc6203..e6aeda7c999e 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1714,7 +1714,7 @@ class BarkEarlyStoppingLogitsProcessor(LogitsProcessor): eos_token_id (`int`, *optional*, defaults to 50257): The id of the *end-of-sequence* token. min_eos_p (`float`, *optional*, defaults to None): - Minimum end of speech threshold. + Minimum end of speech threshold. """ def __init__(self, generate_config): From 1de69b8309b7b13962b4bc1f02469b351857bc06 Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Sun, 8 Oct 2023 23:39:29 +0300 Subject: [PATCH 04/29] follow method signature --- src/transformers/generation/logits_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index e6aeda7c999e..17bac6379d40 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1722,6 +1722,6 @@ def __init__(self, generate_config): self.min_eos_p = generate_config.min_eos_p @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: scores[:, self.eos_token_id] = -float("inf") return scores From 0937f063c4894b5b4d571633cd9eb796e44655ac Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Mon, 9 Oct 2023 14:18:09 +0300 Subject: [PATCH 05/29] actual logic --- src/transformers/generation/logits_process.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 17bac6379d40..0f5e66970267 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1723,5 +1723,7 @@ def __init__(self, generate_config): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - scores[:, self.eos_token_id] = -float("inf") + for k in range(input_ids.shape[0]): + if scores[k, self.eos_token_id] > self.min_eos_p: + scores[k, : self.eos_token_id] = -float("inf") return scores From 
05d1647654bd93c5187bb6eb90ea627667240b82 Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Mon, 9 Oct 2023 17:01:26 +0300 Subject: [PATCH 06/29] check for None --- src/transformers/generation/logits_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 0f5e66970267..33bb45d9e558 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1724,6 +1724,6 @@ def __init__(self, generate_config): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: for k in range(input_ids.shape[0]): - if scores[k, self.eos_token_id] > self.min_eos_p: + if self.min_eos_p is not None and scores[k, self.eos_token_id] > self.min_eos_p: scores[k, : self.eos_token_id] = -float("inf") return scores From 88571d1e531f89246060577583439fad3e10c0f7 Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Mon, 9 Oct 2023 19:40:55 +0300 Subject: [PATCH 07/29] address comments on docstrings and method signature --- src/transformers/generation/logits_process.py | 19 ++++++++----------- .../bark/generation_configuration_bark.py | 5 +++-- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 33bb45d9e558..8ffac9dc08f1 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1705,21 +1705,18 @@ def __call__(self, input_ids, scores): class BarkEarlyStoppingLogitsProcessor(LogitsProcessor): - r"""This processor will set every tokens' log probability other than the EOS token to `-inf` - when the probabiliy of the EOS token id is superior to min_eos_p. + r"""This processor ensures that the EOS token is sampled if its probability is greater than the `min_eos_p`. Args: - generate_config (`GenerateConfig`): - The generate config used to generate the output. The following parameters are required: - eos_token_id (`int`, *optional*, defaults to 50257): - The id of the *end-of-sequence* token. - min_eos_p (`float`, *optional*, defaults to None): - Minimum end of speech threshold. + eos_token_id (`int`): + The id of the *end-of-sequence* token. + min_eos_p (`float`, *optional*): + Minimum end of speech threshold. """ - def __init__(self, generate_config): - self.eos_token_id = generate_config.eos_token_id - self.min_eos_p = generate_config.min_eos_p + def __init__(self, eos_token_id: int, min_eos_p: float): + self.eos_token_id = eos_token_id + self.min_eos_p = min_eos_p @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: diff --git a/src/transformers/models/bark/generation_configuration_bark.py b/src/transformers/models/bark/generation_configuration_bark.py index e2df243d2d97..044ad6bece96 100644 --- a/src/transformers/models/bark/generation_configuration_bark.py +++ b/src/transformers/models/bark/generation_configuration_bark.py @@ -87,8 +87,9 @@ def __init__( Max length of semantic input vector. semantic_rate_hz (`float`, *optional*, defaults to 49.9): Semantic rate in Hertz. - min_eos_p (`float`, *optional*, defaults to None): - Minimum end of speech threshold. + min_eos_p (`float`, *optional*): + Minimum threshold of the probability of the EOS token for it to be sampled. 
This is an early + stopping strategy to mitigate potential unwanted generations at the end of a prompt. """ super().__init__( temperature=temperature, From 528b967631dcf59fb493117afee19ae730b51b5b Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Mon, 9 Oct 2023 21:28:20 +0300 Subject: [PATCH 08/29] add unit test under `LogitsProcessorTest` wip --- src/transformers/generation/logits_process.py | 23 ++++++++++++++----- tests/generation/test_logits_process.py | 16 +++++++++++++ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 8ffac9dc08f1..bdd76d357a2e 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1708,19 +1708,30 @@ class BarkEarlyStoppingLogitsProcessor(LogitsProcessor): r"""This processor ensures that the EOS token is sampled if its probability is greater than the `min_eos_p`. Args: - eos_token_id (`int`): - The id of the *end-of-sequence* token. + eos_token_id (`Union[int, List[int]]`): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. min_eos_p (`float`, *optional*): Minimum end of speech threshold. """ - def __init__(self, eos_token_id: int, min_eos_p: float): + def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] self.eos_token_id = eos_token_id + if min_eos_p <= 0: + raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}") self.min_eos_p = min_eos_p @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - for k in range(input_ids.shape[0]): - if self.min_eos_p is not None and scores[k, self.eos_token_id] > self.min_eos_p: - scores[k, : self.eos_token_id] = -float("inf") + if self.min_eos_p: + # create scores full of -inf except for the eos_token_id + early_stop_scores = torch.ones_like(scores) * -float("inf") + early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id] + + # TODO: confirm threshold. mask requires work. 
+ do_early_stop = scores[:, self.eos_token_id, None] > np.log(self.min_eos_p) + + scores = torch.where(do_early_stop, early_stop_scores, scores) + return scores diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index 32bd02936d1c..6e7cdc061172 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -54,6 +54,8 @@ UnbatchedClassifierFreeGuidanceLogitsProcessor, ) + from transformers.generation.logits_process import BarkEarlyStoppingLogitsProcessor + @require_torch class LogitsProcessorTest(unittest.TestCase): @@ -800,3 +802,17 @@ def lsm(x): self.assertAlmostEqual(out[0].item(), res[0].item()) self.assertAlmostEqual(out[1].item(), res[1].item()) self.assertAlmostEqual(out[2].item(), res[2].item()) + + def test_early_stop_processor(self): + input_ids = None + scores = self._get_uniform_logits(2, 4) + eos_token_id = 2 + min_eos_p = 0.1 ## some small float + + esp = BarkEarlyStoppingLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) + actual_scores = esp(input_ids, scores) + expected_scores_list = [ + [float("-inf"), float("-inf"), 0.1, float("-inf")], + [float("-inf"), float("-inf"), 0.3, float("-inf")], + ] + self.assertListEqual(actual_scores.tolist(), expected_scores_list) From dc33e369980f582306a28accf623aa43f7f6f15a Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Tue, 10 Oct 2023 10:09:07 +0300 Subject: [PATCH 09/29] unit test passing --- src/transformers/generation/logits_process.py | 7 +++---- tests/generation/test_logits_process.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index bdd76d357a2e..d54e0eabd435 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1729,9 +1729,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to early_stop_scores = torch.ones_like(scores) * -float("inf") early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id] - # TODO: confirm threshold. mask requires work. 
- do_early_stop = scores[:, self.eos_token_id, None] > np.log(self.min_eos_p) - - scores = torch.where(do_early_stop, early_stop_scores, scores) + do_early_stop = torch.any(scores[:, self.eos_token_id] > np.log(self.min_eos_p)) + if do_early_stop: + scores = early_stop_scores return scores diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index 6e7cdc061172..5de7a034eee9 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -812,7 +812,7 @@ def test_early_stop_processor(self): esp = BarkEarlyStoppingLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) actual_scores = esp(input_ids, scores) expected_scores_list = [ - [float("-inf"), float("-inf"), 0.1, float("-inf")], - [float("-inf"), float("-inf"), 0.3, float("-inf")], + [float("-inf"), float("-inf"), scores[0][0], float("-inf")], + [float("-inf"), float("-inf"), scores[0][0], float("-inf")], ] self.assertListEqual(actual_scores.tolist(), expected_scores_list) From ce21bc5a4b28e95d1adcf49f43f3bf4179d040ba Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Tue, 10 Oct 2023 10:16:03 +0300 Subject: [PATCH 10/29] black formatted --- src/transformers/models/bark/generation_configuration_bark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bark/generation_configuration_bark.py b/src/transformers/models/bark/generation_configuration_bark.py index 044ad6bece96..872b7e3ec8ab 100644 --- a/src/transformers/models/bark/generation_configuration_bark.py +++ b/src/transformers/models/bark/generation_configuration_bark.py @@ -88,7 +88,7 @@ def __init__( semantic_rate_hz (`float`, *optional*, defaults to 49.9): Semantic rate in Hertz. min_eos_p (`float`, *optional*): - Minimum threshold of the probability of the EOS token for it to be sampled. This is an early + Minimum threshold of the probability of the EOS token for it to be sampled. This is an early stopping strategy to mitigate potential unwanted generations at the end of a prompt. 
""" super().__init__( From 45b3065e8ec637f169aeb1f0575614c07337d449 Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Tue, 10 Oct 2023 12:10:52 +0300 Subject: [PATCH 11/29] condition per sample --- src/transformers/generation/logits_process.py | 5 ++--- tests/generation/test_logits_process.py | 7 ++++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index d54e0eabd435..c3987f6549db 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1729,8 +1729,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to early_stop_scores = torch.ones_like(scores) * -float("inf") early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id] - do_early_stop = torch.any(scores[:, self.eos_token_id] > np.log(self.min_eos_p)) - if do_early_stop: - scores = early_stop_scores + do_early_stop = scores[:, self.eos_token_id] > np.log(self.min_eos_p) + scores = torch.where(do_early_stop, early_stop_scores, scores) return scores diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index 5de7a034eee9..01b51eec391c 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -53,7 +53,6 @@ TypicalLogitsWarper, UnbatchedClassifierFreeGuidanceLogitsProcessor, ) - from transformers.generation.logits_process import BarkEarlyStoppingLogitsProcessor @@ -805,14 +804,16 @@ def lsm(x): def test_early_stop_processor(self): input_ids = None - scores = self._get_uniform_logits(2, 4) eos_token_id = 2 min_eos_p = 0.1 ## some small float + scores = self._get_uniform_logits(2, 4) + scores[0][eos_token_id] = -6 ## less than log(min_eos_p) + esp = BarkEarlyStoppingLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) actual_scores = esp(input_ids, scores) expected_scores_list = [ - [float("-inf"), float("-inf"), scores[0][0], float("-inf")], + scores[0].tolist(), [float("-inf"), float("-inf"), scores[0][0], float("-inf")], ] self.assertListEqual(actual_scores.tolist(), expected_scores_list) From fa5d251687cc0da604371b20e307b40ae2a97fd2 Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Tue, 10 Oct 2023 23:08:48 +0300 Subject: [PATCH 12/29] add to BarkModelIntegrationTests --- src/transformers/generation/logits_process.py | 2 +- src/transformers/models/bark/modeling_bark.py | 6 +++++- tests/models/bark/test_modeling_bark.py | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index c3987f6549db..3f9c3e5c82f4 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1718,7 +1718,7 @@ def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float): if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] self.eos_token_id = eos_token_id - if min_eos_p <= 0: + if min_eos_p is not None and min_eos_p <= 0: raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}") self.min_eos_p = min_eos_p diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 1f27e41186ef..69a0bc1f660e 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -797,7 +797,11 @@ def generate( ) suppress_tokens_logits_processor = 
SuppressTokensLogitsProcessor(tokens_to_suppress) - early_stopping_logits_processor = BarkEarlyStoppingLogitsProcessor(semantic_generation_config) + + min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p) + early_stopping_logits_processor = BarkEarlyStoppingLogitsProcessor( + eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p + ) # pass input_ids in order to stay consistent with the transformers generate method even though it is not used # (except to get the input seq_len - that's why we keep the first 257 tokens) diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index 3a5de30147e2..e1c387f884c5 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -74,6 +74,7 @@ def __init__( initializer_range=0.02, n_codes_total=8, # for BarkFineModel n_codes_given=1, # for BarkFineModel + min_eos_p=None, ): self.parent = parent self.batch_size = batch_size @@ -93,6 +94,7 @@ def __init__( self.bos_token_id = output_vocab_size - 1 self.eos_token_id = output_vocab_size - 1 self.pad_token_id = output_vocab_size - 1 + self.min_eos_p = min_eos_p self.n_codes_total = n_codes_total self.n_codes_given = n_codes_given @@ -1041,6 +1043,7 @@ def test_generate_end_to_end_with_sub_models_args(self): semantic_temperature=0.9, coarse_temperature=0.2, fine_temperature=0.1, + semantic_min_eos_p=0.1, ) @require_torch_gpu From b08346dc043a05e846c628af997a109fb4991756 Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Tue, 10 Oct 2023 23:29:28 +0300 Subject: [PATCH 13/29] wip BarkSemanticModelTest --- tests/models/bark/test_modeling_bark.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index e1c387f884c5..6c274dc97029 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -74,7 +74,6 @@ def __init__( initializer_range=0.02, n_codes_total=8, # for BarkFineModel n_codes_given=1, # for BarkFineModel - min_eos_p=None, ): self.parent = parent self.batch_size = batch_size @@ -94,7 +93,6 @@ def __init__( self.bos_token_id = output_vocab_size - 1 self.eos_token_id = output_vocab_size - 1 self.pad_token_id = output_vocab_size - 1 - self.min_eos_p = min_eos_p self.n_codes_total = n_codes_total self.n_codes_given = n_codes_given @@ -582,6 +580,14 @@ def test_generate_fp16(self): model.generate(input_ids, attention_mask=attention_mask) model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + model = self.all_model_classes[0](config).eval().to(torch_device) + model.generate( + input_ids, + semantic_generation_config=BarkSemanticGenerationConfig( + max_input_semantic_length=int(len(input_ids) / 2), max_new_tokens=10, min_eos_p=0.1 + ), + ) + @require_torch class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): From 81f90ce556986e2d85b2b3334f1f4bb5ec5db17f Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Wed, 11 Oct 2023 19:47:47 +0300 Subject: [PATCH 14/29] rename and add to kwargs handling --- src/transformers/generation/logits_process.py | 2 +- src/transformers/models/bark/modeling_bark.py | 7 ++++--- tests/generation/test_logits_process.py | 4 ++-- tests/models/bark/test_modeling_bark.py | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 3f9c3e5c82f4..d42ab8482720 
100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1704,7 +1704,7 @@ def __call__(self, input_ids, scores): return out -class BarkEarlyStoppingLogitsProcessor(LogitsProcessor): +class EarlyStoppingLogitsProcessor(LogitsProcessor): r"""This processor ensures that the EOS token is sampled if its probability is greater than the `min_eos_p`. Args: diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 69a0bc1f660e..669962d42d1c 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -23,7 +23,7 @@ from ...generation.logits_process import ( AlternatingCodebooksLogitsProcessor, - BarkEarlyStoppingLogitsProcessor, + EarlyStoppingLogitsProcessor, SuppressTokensLogitsProcessor, ) from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput @@ -799,7 +799,7 @@ def generate( suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress) min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p) - early_stopping_logits_processor = BarkEarlyStoppingLogitsProcessor( + early_stopping_logits_processor = EarlyStoppingLogitsProcessor( eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p ) @@ -1564,7 +1564,8 @@ def generate( kwargs_semantic = { # if "attention_mask" is set, it should not be passed to CoarseModel and FineModel - "attention_mask": kwargs.pop("attention_mask", None) + "attention_mask": kwargs.pop("attention_mask", None), + "min_eos_p": kwargs.pop("min_eos_p", None), } kwargs_coarse = {} kwargs_fine = {} diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index 01b51eec391c..25a02f6e2107 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -53,7 +53,7 @@ TypicalLogitsWarper, UnbatchedClassifierFreeGuidanceLogitsProcessor, ) - from transformers.generation.logits_process import BarkEarlyStoppingLogitsProcessor + from transformers.generation.logits_process import EarlyStoppingLogitsProcessor @require_torch @@ -810,7 +810,7 @@ def test_early_stop_processor(self): scores = self._get_uniform_logits(2, 4) scores[0][eos_token_id] = -6 ## less than log(min_eos_p) - esp = BarkEarlyStoppingLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) + esp = EarlyStoppingLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) actual_scores = esp(input_ids, scores) expected_scores_list = [ scores[0].tolist(), diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index 6c274dc97029..77c34622c4df 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -1049,7 +1049,7 @@ def test_generate_end_to_end_with_sub_models_args(self): semantic_temperature=0.9, coarse_temperature=0.2, fine_temperature=0.1, - semantic_min_eos_p=0.1, + min_eos_p=0.1, ) @require_torch_gpu From f9c1c8479f9cb4cecd2b39d657edca9faddfd202 Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Thu, 12 Oct 2023 09:30:45 +0300 Subject: [PATCH 15/29] not add to BarkSemanticModelTest --- tests/models/bark/test_modeling_bark.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index 77c34622c4df..5aaa745f9430 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -580,14 +580,6 @@ def 
test_generate_fp16(self): model.generate(input_ids, attention_mask=attention_mask) model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) - model = self.all_model_classes[0](config).eval().to(torch_device) - model.generate( - input_ids, - semantic_generation_config=BarkSemanticGenerationConfig( - max_input_semantic_length=int(len(input_ids) / 2), max_new_tokens=10, min_eos_p=0.1 - ), - ) - @require_torch class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): From a9645bbf96a1f6365d291cb67f16bc0d5e292c9c Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Fri, 13 Oct 2023 20:54:38 +0300 Subject: [PATCH 16/29] correct logic and assert last outputs tokens different in test --- src/transformers/generation/logits_process.py | 3 ++- tests/models/bark/test_modeling_bark.py | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index d42ab8482720..59c0fd3dd4e6 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1724,12 +1724,13 @@ def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + logprobs = torch.nn.functional.log_softmax(scores.float(), dim=-1) if self.min_eos_p: # create scores full of -inf except for the eos_token_id early_stop_scores = torch.ones_like(scores) * -float("inf") early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id] - do_early_stop = scores[:, self.eos_token_id] > np.log(self.min_eos_p) + do_early_stop = logprobs[:, self.eos_token_id] > np.log(self.min_eos_p) scores = torch.where(do_early_stop, early_stop_scores, scores) return scores diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index 5aaa745f9430..89fcc2ad4574 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -920,6 +920,27 @@ def test_generate_semantic(self): self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids) + @slow + def test_generate_semantic_early_stop(self): + input_ids = self.inputs + + # fmt: off + # check first ids + expected_output_ids = [7363, 321, 41, 1461, 6915, 952, 326, 41, 41, 927,] + # fmt: on + + self.semantic_generation_config.min_eos_p = 0.05 + + with torch.no_grad(): + output_ids = self.model.semantic.generate( + **input_ids, + do_sample=True, + temperature=1.0, + semantic_generation_config=self.semantic_generation_config, + ) + + self.assertNotEqual(output_ids[0, : len(expected_output_ids)].tolist()[-1], [-1]) + @slow def test_generate_coarse(self): input_ids = self.inputs From 4c81b7dc405edda377744d670ac51b550125e479 Mon Sep 17 00:00:00 2001 From: Isaac Chung Date: Fri, 13 Oct 2023 22:20:49 +0300 Subject: [PATCH 17/29] doc-builder style --- src/transformers/models/bark/generation_configuration_bark.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bark/generation_configuration_bark.py b/src/transformers/models/bark/generation_configuration_bark.py index 872b7e3ec8ab..11b484024b81 100644 --- a/src/transformers/models/bark/generation_configuration_bark.py +++ b/src/transformers/models/bark/generation_configuration_bark.py @@ -88,8 +88,8 @@ def __init__( semantic_rate_hz (`float`, 
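The key fix in patch 16 is that `scores` reaching a logits processor are raw, unnormalized logits, so comparing them directly against `log(min_eos_p)` (as the earlier revisions did) carries no probability meaning until they are normalized. A toy check of the difference, with made-up numbers rather than values from the PR's tests:

import torch

scores = torch.tensor([[2.0, 2.0, 4.0, 2.0]])  # raw logits, EOS at index 2
min_eos_p = 0.2

# Pre-patch-16 behavior: a raw logit of 4.0 is "greater than log(0.2)" for
# almost any input, so the threshold fires regardless of actual likelihood.
print(scores[0, 2] > torch.log(torch.tensor(min_eos_p)))  # tensor(True), vacuously

# Patch-16 behavior: normalize first. Here P(eos) = softmax(...)[2] ~= 0.71,
# so the comparison now genuinely asks "is EOS at least 20% likely?".
logprobs = torch.nn.functional.log_softmax(scores.float(), dim=-1)
print(logprobs[0, 2] > torch.log(torch.tensor(min_eos_p)))  # tensor(True), meaningfully
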
From 4c81b7dc405edda377744d670ac51b550125e479 Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Fri, 13 Oct 2023 22:20:49 +0300
Subject: [PATCH 17/29] doc-builder style

---
 src/transformers/models/bark/generation_configuration_bark.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/bark/generation_configuration_bark.py b/src/transformers/models/bark/generation_configuration_bark.py
index 872b7e3ec8ab..11b484024b81 100644
--- a/src/transformers/models/bark/generation_configuration_bark.py
+++ b/src/transformers/models/bark/generation_configuration_bark.py
@@ -88,8 +88,8 @@ def __init__(
             semantic_rate_hz (`float`, *optional*, defaults to 49.9):
                 Semantic rate in Hertz.
             min_eos_p (`float`, *optional*):
-                Minimum threshold of the probability of the EOS token for it to be sampled. This is an early
-                stopping strategy to mitigate potential unwanted generations at the end of a prompt.
+                Minimum threshold of the probability of the EOS token for it to be sampled. This is an early stopping
+                strategy to mitigate potential unwanted generations at the end of a prompt.
         """
         super().__init__(
             temperature=temperature,

From ff7343b60c063d147fde844b2d234f6ee19d5249 Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Thu, 19 Oct 2023 11:35:35 +0300
Subject: [PATCH 18/29] read from kwargs as well

---
 tests/models/bark/test_modeling_bark.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index 89fcc2ad4574..ab781ec06bfe 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -929,8 +929,19 @@ def test_generate_semantic_early_stop(self):
         expected_output_ids = [7363, 321, 41, 1461, 6915, 952, 326, 41, 41, 927,]
         # fmt: on

-        self.semantic_generation_config.min_eos_p = 0.05
+        # Should be aboe to read min_eos_p from kwargs
+        with torch.no_grad():
+            output_ids = self.model.semantic.generate(
+                **input_ids,
+                do_sample=True,
+                temperature=1.0,
+                semantic_generation_config=self.semantic_generation_config,
+                min_eos_p=0.01,
+            )
+        self.assertNotEqual(output_ids[0, : len(expected_output_ids)].tolist()[-1], [-1])

+        # Should be able to read min_eos_p from the semantic generation config
+        self.semantic_generation_config.min_eos_p = 0.05
         with torch.no_grad():
             output_ids = self.model.semantic.generate(
                 **input_ids,

From 9e90232090301ef17724b129be9560804f53fb30 Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Thu, 19 Oct 2023 13:49:40 +0300
Subject: [PATCH 19/29] assert len of with less than that of without

---
 tests/models/bark/test_modeling_bark.py | 36 ++++++++++++++--------------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index ab781ec06bfe..64912147238f 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -41,6 +41,8 @@
 from ..encodec.test_modeling_encodec import EncodecModelTester

+from scipy.io.wavfile import write as write_wav
+

 if is_torch_available():
     import torch
@@ -917,8 +919,7 @@ def test_generate_semantic(self):
             temperature=1.0,
             semantic_generation_config=self.semantic_generation_config,
         )
-
-        self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)
+        self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist()[-1], expected_output_ids[-1])

     @slow
     def test_generate_semantic_early_stop(self):
@@ -938,10 +939,10 @@ def test_generate_semantic_early_stop(self):
                 semantic_generation_config=self.semantic_generation_config,
                 min_eos_p=0.01,
             )
-        self.assertNotEqual(output_ids[0, : len(expected_output_ids)].tolist()[-1], [-1])
+        self.assertLessEqual(len(output_ids.tolist()), len(expected_output_ids))

         # Should be able to read min_eos_p from the semantic generation config
-        self.semantic_generation_config.min_eos_p = 0.05
+        self.semantic_generation_config.min_eos_p = 0.01
         with torch.no_grad():
             output_ids = self.model.semantic.generate(
                 **input_ids,
                 do_sample=True,
                 temperature=1.0,
                 semantic_generation_config=self.semantic_generation_config,
             )
-
-        self.assertNotEqual(output_ids[0, : len(expected_output_ids)].tolist()[-1], [-1])
+        self.assertLessEqual(len(output_ids.tolist()), len(expected_output_ids))

     @slow
     def test_generate_coarse(self):
@@ -1054,27 +1054,27 @@ def test_generate_end_to_end_with_sub_models_args(self):
         input_ids = self.inputs

         with torch.no_grad():
-            self.model.generate(
-                **input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
-            )
-            self.model.generate(
+            # self.model.generate(
+            #     **input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
+            # )
+            output_ids_without_min_eos_p = self.model.generate(
                 **input_ids,
-                do_sample=False,
-                temperature=1.0,
+                do_sample=True,
+                temperature=0.9,
                 coarse_do_sample=True,
                 coarse_temperature=0.7,
                 fine_temperature=0.3,
             )
-            self.model.generate(
+
+            output_ids_with_min_eos_p = self.model.generate(
                 **input_ids,
                 do_sample=True,
-                temperature=0.6,
-                penalty_alpha=0.6,
-                semantic_temperature=0.9,
-                coarse_temperature=0.2,
-                fine_temperature=0.1,
+                temperature=0.9,
+                coarse_temperature=0.7,
+                fine_temperature=0.3,
                 min_eos_p=0.1,
             )
+        self.assertLessEqual(len(output_ids_with_min_eos_p.tolist()), len(output_ids_without_min_eos_p.tolist()))

From a37aac3b17ebfc3c59bcfcb3a0d9cd8b2383e3d5 Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Thu, 19 Oct 2023 13:59:29 +0300
Subject: [PATCH 20/29] ruff

---
 tests/models/bark/test_modeling_bark.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index 64912147238f..69ce7ed8ab9f 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -41,8 +41,6 @@
 from ..encodec.test_modeling_encodec import EncodecModelTester

-from scipy.io.wavfile import write as write_wav
-

 if is_torch_available():
     import torch

From b0200b8c15c30b60e8d4bd1570cbf174de083ae7 Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Thu, 19 Oct 2023 14:49:50 +0300
Subject: [PATCH 21/29] add back seed and test case

---
 tests/models/bark/test_modeling_bark.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index 69ce7ed8ab9f..633fc12df7b6 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -1052,9 +1052,10 @@ def test_generate_end_to_end_with_sub_models_args(self):
         input_ids = self.inputs

         with torch.no_grad():
-            # self.model.generate(
-            #     **input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
-            # )
+            torch.manual_seed(0)
+            self.model.generate(
+                **input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
+            )
             output_ids_without_min_eos_p = self.model.generate(
                 **input_ids,
                 do_sample=True,
                 temperature=0.9,

From 7c743c646a44e3e3d3bfd5955e4269fd279ad976 Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Thu, 19 Oct 2023 19:01:10 +0300
Subject: [PATCH 22/29] add original impl default suggestion

---
 src/transformers/models/bark/generation_configuration_bark.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/models/bark/generation_configuration_bark.py b/src/transformers/models/bark/generation_configuration_bark.py
index 11b484024b81..dd8ac221f05b 100644
--- a/src/transformers/models/bark/generation_configuration_bark.py
+++ b/src/transformers/models/bark/generation_configuration_bark.py
@@ -90,6 +90,7 @@ def __init__(
             min_eos_p (`float`, *optional*):
                 Minimum threshold of the probability of the EOS token for it to be sampled. This is an early stopping
                 strategy to mitigate potential unwanted generations at the end of a prompt.
+                The original implementation suggests a default value of 0.2.
         """
         super().__init__(
             temperature=temperature,
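At this point the threshold can be supplied through either route the tests exercise. A sketch assuming an already-loaded `BarkModel` named `model`, tokenized `inputs`, and a `BarkSemanticGenerationConfig` named `semantic_generation_config` (all assumptions for illustration); the value 0.2 follows the default suggested above:

# 1) As a per-call kwarg, read via kwargs.get(...) in patches 12/14:
output_ids = model.semantic.generate(
    **inputs,
    semantic_generation_config=semantic_generation_config,
    min_eos_p=0.2,
)

# 2) Via the generation config itself (the kwargs fallback):
semantic_generation_config.min_eos_p = 0.2
output_ids = model.semantic.generate(
    **inputs,
    semantic_generation_config=semantic_generation_config,
)
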
From 28c12e92c576d1dde181bef7500535c6db061c64 Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Fri, 20 Oct 2023 11:06:32 +0300
Subject: [PATCH 23/29] doc-builder

---
 src/transformers/models/bark/generation_configuration_bark.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/bark/generation_configuration_bark.py b/src/transformers/models/bark/generation_configuration_bark.py
index dd8ac221f05b..7d7d98449d66 100644
--- a/src/transformers/models/bark/generation_configuration_bark.py
+++ b/src/transformers/models/bark/generation_configuration_bark.py
@@ -89,8 +89,8 @@ def __init__(
                 Semantic rate in Hertz.
             min_eos_p (`float`, *optional*):
                 Minimum threshold of the probability of the EOS token for it to be sampled. This is an early stopping
-                strategy to mitigate potential unwanted generations at the end of a prompt.
-                The original implementation suggests a default value of 0.2.
+                strategy to mitigate potential unwanted generations at the end of a prompt. The original implementation
+                suggests a default value of 0.2.
         """
         super().__init__(
             temperature=temperature,

From 542f41e290a9f618887484a772396a62c6f6c453 Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Tue, 24 Oct 2023 23:03:39 +0300
Subject: [PATCH 24/29] rename and use softmax

---
 src/transformers/generation/logits_process.py | 8 ++++----
 src/transformers/models/bark/modeling_bark.py | 4 ++--
 tests/generation/test_logits_process.py       | 4 ++--
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 59c0fd3dd4e6..bbd348fde95e 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -1704,8 +1704,8 @@ def __call__(self, input_ids, scores):
     return out


-class EarlyStoppingLogitsProcessor(LogitsProcessor):
-    r"""This processor ensures that the EOS token is sampled if its probability is greater than the `min_eos_p`.
+class BarkEOSPrioritizerLogitsWarper(LogitsWarper):
+    r"""This warper ensures that the EOS token is sampled if its probability is greater than the `min_eos_p`.

     Args:
         eos_token_id (`Union[int, List[int]]`):
@@ -1724,13 +1724,13 @@ def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        logprobs = torch.nn.functional.log_softmax(scores.float(), dim=-1)
+        probs = torch.nn.functional.softmax(scores.float(), dim=-1)
         if self.min_eos_p:
             # create scores full of -inf except for the eos_token_id
             early_stop_scores = torch.ones_like(scores) * -float("inf")
             early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]

-            do_early_stop = logprobs[:, self.eos_token_id] > np.log(self.min_eos_p)
+            do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
             scores = torch.where(do_early_stop, early_stop_scores, scores)

         return scores
diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index 669962d42d1c..5e78be86bf9d 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -23,7 +23,7 @@
 from ...generation.logits_process import (
     AlternatingCodebooksLogitsProcessor,
-    EarlyStoppingLogitsProcessor,
+    BarkEOSPrioritizerLogitsWarper,
     SuppressTokensLogitsProcessor,
 )
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
@@ -799,7 +799,7 @@ def generate(
         min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p)
-        early_stopping_logits_processor = EarlyStoppingLogitsProcessor(
+        early_stopping_logits_processor = BarkEOSPrioritizerLogitsWarper(
             eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p
         )
diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py
index 25a02f6e2107..d32fb1753b1d 100644
--- a/tests/generation/test_logits_process.py
+++ b/tests/generation/test_logits_process.py
@@ -53,7 +53,7 @@
         TypicalLogitsWarper,
         UnbatchedClassifierFreeGuidanceLogitsProcessor,
     )
-    from transformers.generation.logits_process import EarlyStoppingLogitsProcessor
+    from transformers.generation.logits_process import BarkEOSPrioritizerLogitsWarper
@@ -810,7 +810,7 @@ def test_early_stop_processor(self):
         scores = self._get_uniform_logits(2, 4)
         scores[0][eos_token_id] = -6  ## less than log(min_eos_p)

-        esp = EarlyStoppingLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
+        esp = BarkEOSPrioritizerLogitsWarper(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
         actual_scores = esp(input_ids, scores)
         expected_scores_list = [
             scores[0].tolist(),

From e16309d8d004956d108a5c5a4148b7f8e61fd1fe Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Wed, 25 Oct 2023 15:37:57 +0300
Subject: [PATCH 25/29] switch back to LogitsProcessor and update docs wording

---
 src/transformers/generation/logits_process.py | 4 ++--
 src/transformers/models/bark/modeling_bark.py | 4 ++--
 tests/generation/test_logits_process.py       | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index bbd348fde95e..514f775e9fd8 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -1704,8 +1704,8 @@ def __call__(self, input_ids, scores):
     return out


-class BarkEOSPrioritizerLogitsWarper(LogitsWarper):
-    r"""This warper ensures that the EOS token is sampled if its probability is greater than the `min_eos_p`.
+class BarkEOSPrioritizerLogitsProcessor(LogitsProcessor):
+    r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`.

     Args:
         eos_token_id (`Union[int, List[int]]`):
diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index 5e78be86bf9d..90fae555cc11 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -23,7 +23,7 @@
 from ...generation.logits_process import (
     AlternatingCodebooksLogitsProcessor,
-    BarkEOSPrioritizerLogitsWarper,
+    BarkEOSPrioritizerLogitsProcessor,
     SuppressTokensLogitsProcessor,
 )
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
@@ -799,7 +799,7 @@ def generate(
         min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p)
-        early_stopping_logits_processor = BarkEOSPrioritizerLogitsWarper(
+        early_stopping_logits_processor = BarkEOSPrioritizerLogitsProcessor(
             eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p
         )
diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py
index d32fb1753b1d..5e921f0d030a 100644
--- a/tests/generation/test_logits_process.py
+++ b/tests/generation/test_logits_process.py
@@ -53,7 +53,7 @@
         TypicalLogitsWarper,
         UnbatchedClassifierFreeGuidanceLogitsProcessor,
     )
-    from transformers.generation.logits_process import BarkEOSPrioritizerLogitsWarper
+    from transformers.generation.logits_process import BarkEOSPrioritizerLogitsProcessor
@@ -810,7 +810,7 @@ def test_early_stop_processor(self):
         scores = self._get_uniform_logits(2, 4)
         scores[0][eos_token_id] = -6  ## less than log(min_eos_p)

-        esp = BarkEOSPrioritizerLogitsWarper(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
+        esp = BarkEOSPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
         actual_scores = esp(input_ids, scores)
         expected_scores_list = [
             scores[0].tolist(),

From 0c75c1c873050a1505626973500548b028bd55a9 Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Wed, 25 Oct 2023 20:46:46 +0300
Subject: [PATCH 26/29] camelCase and spelling and saving compute

---
 src/transformers/generation/logits_process.py | 4 ++--
 src/transformers/models/bark/modeling_bark.py | 4 ++--
 tests/generation/test_logits_process.py       | 4 ++--
 tests/models/bark/test_modeling_bark.py       | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 514f775e9fd8..3787680f4976 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -1704,7 +1704,7 @@ def __call__(self, input_ids, scores):
     return out


-class BarkEOSPrioritizerLogitsProcessor(LogitsProcessor):
+class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
     r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`.

     Args:
@@ -1724,8 +1724,8 @@ def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        probs = torch.nn.functional.softmax(scores.float(), dim=-1)
         if self.min_eos_p:
+            probs = torch.nn.functional.softmax(scores.float(), dim=-1)
             # create scores full of -inf except for the eos_token_id
             early_stop_scores = torch.ones_like(scores) * -float("inf")
             early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]
diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index 90fae555cc11..4c7a9e6da1bb 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -23,7 +23,7 @@
 from ...generation.logits_process import (
     AlternatingCodebooksLogitsProcessor,
-    BarkEOSPrioritizerLogitsProcessor,
+    BarkEosPrioritizerLogitsProcessor,
     SuppressTokensLogitsProcessor,
 )
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
@@ -799,7 +799,7 @@ def generate(
         min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p)
-        early_stopping_logits_processor = BarkEOSPrioritizerLogitsProcessor(
+        early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor(
             eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p
         )
diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py
index 5e921f0d030a..15f5cf1e4f46 100644
--- a/tests/generation/test_logits_process.py
+++ b/tests/generation/test_logits_process.py
@@ -53,7 +53,7 @@
         TypicalLogitsWarper,
         UnbatchedClassifierFreeGuidanceLogitsProcessor,
     )
-    from transformers.generation.logits_process import BarkEOSPrioritizerLogitsProcessor
+    from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
@@ -810,7 +810,7 @@ def test_early_stop_processor(self):
         scores = self._get_uniform_logits(2, 4)
         scores[0][eos_token_id] = -6  ## less than log(min_eos_p)

-        esp = BarkEOSPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
+        esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
         actual_scores = esp(input_ids, scores)
         expected_scores_list = [
             scores[0].tolist(),
diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index 633fc12df7b6..ed3d4416c213 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -928,7 +928,7 @@ def test_generate_semantic_early_stop(self):
         expected_output_ids = [7363, 321, 41, 1461, 6915, 952, 326, 41, 41, 927,]
         # fmt: on

-        # Should be aboe to read min_eos_p from kwargs
+        # Should be able to read min_eos_p from kwargs
         with torch.no_grad():
             output_ids = self.model.semantic.generate(
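With patch 26 the processor reaches its final shape, so its behavior can be checked standalone. A sketch that re-implements the `__call__` body outside the library; the tensor values are illustrative, mirroring the unit test's uniform-logits setup:

import torch

def eos_prioritize(scores, eos_token_id, min_eos_p):
    # Mirror of BarkEosPrioritizerLogitsProcessor.__call__ after patch 26:
    # when P(eos) exceeds min_eos_p for a row, every other token in that
    # row is masked to -inf so EOS is guaranteed to be picked.
    eos_token_id = [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
    probs = torch.nn.functional.softmax(scores.float(), dim=-1)
    early_stop_scores = torch.ones_like(scores) * -float("inf")
    early_stop_scores[:, eos_token_id] = scores[:, eos_token_id]
    do_early_stop = probs[:, eos_token_id] > min_eos_p  # (batch, n_eos), broadcasts per row
    return torch.where(do_early_stop, early_stop_scores, scores)

scores = torch.tensor([
    [1.0, 1.0, 1.0, 1.0],   # uniform: P(eos) = 0.25 > 0.1 -> row masked to EOS-only
    [1.0, 1.0, -6.0, 1.0],  # P(eos) is tiny -> row passes through unchanged
])
print(eos_prioritize(scores, eos_token_id=2, min_eos_p=0.1))
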
From d88acd822da11f02395d5de287320aebc0eb545e Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Thu, 26 Oct 2023 01:06:55 +0300
Subject: [PATCH 27/29] assert strictly less than

---
 tests/models/bark/test_modeling_bark.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index ed3d4416c213..f7b9b9ffe6fb 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -917,7 +917,7 @@ def test_generate_semantic(self):
             temperature=1.0,
             semantic_generation_config=self.semantic_generation_config,
         )
-        self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist()[-1], expected_output_ids[-1])
+        self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)

     @slow
     def test_generate_semantic_early_stop(self):
@@ -937,7 +937,7 @@ def test_generate_semantic_early_stop(self):
                 semantic_generation_config=self.semantic_generation_config,
                 min_eos_p=0.01,
             )
-        self.assertLessEqual(len(output_ids.tolist()), len(expected_output_ids))
+        self.assertLess(len(output_ids.tolist()), len(expected_output_ids))

         # Should be able to read min_eos_p from the semantic generation config
         self.semantic_generation_config.min_eos_p = 0.01
@@ -948,7 +948,7 @@ def test_generate_semantic_early_stop(self):
                 semantic_generation_config=self.semantic_generation_config,
             )
-        self.assertLessEqual(len(output_ids.tolist()), len(expected_output_ids))
+        self.assertLess(len(output_ids.tolist()), len(expected_output_ids))

     @slow
     def test_generate_coarse(self):

From eb9b2f0addd03ede543e621acb36fb0051d747d9 Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Thu, 26 Oct 2023 10:43:38 +0300
Subject: [PATCH 28/29] assert less than

---
 tests/models/bark/test_modeling_bark.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index f7b9b9ffe6fb..91b7cbf80950 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -1073,7 +1073,9 @@ def test_generate_end_to_end_with_sub_models_args(self):
                 fine_temperature=0.3,
                 min_eos_p=0.1,
             )
-        self.assertLessEqual(len(output_ids_with_min_eos_p.tolist()), len(output_ids_without_min_eos_p.tolist()))
+        self.assertLess(
+            len(output_ids_with_min_eos_p[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist())
+        )

From 9e835472a43a7968dfd7bb3ee07d173be4d19b2b Mon Sep 17 00:00:00 2001
From: Isaac Chung
Date: Fri, 27 Oct 2023 10:40:25 +0300
Subject: [PATCH 29/29] expand test_generate_semantic_early_stop instead

---
 tests/models/bark/test_modeling_bark.py | 32 ++++++++++++++++++-------
 1 file changed, 23 insertions(+), 9 deletions(-)

diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index 91b7cbf80950..d80ee24a1610 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -922,6 +922,7 @@ def test_generate_semantic(self):
     @slow
     def test_generate_semantic_early_stop(self):
         input_ids = self.inputs
+        min_eos_p = 0.01

         # fmt: off
         # check first ids
@@ -930,25 +931,38 @@ def test_generate_semantic_early_stop(self):

         # Should be able to read min_eos_p from kwargs
         with torch.no_grad():
-            output_ids = self.model.semantic.generate(
+            torch.manual_seed(0)
+            output_ids_without_min_eos_p = self.model.semantic.generate(
                 **input_ids,
-                do_sample=True,
-                temperature=1.0,
+                do_sample=False,
+                temperature=0.9,
                 semantic_generation_config=self.semantic_generation_config,
-                min_eos_p=0.01,
             )
-        self.assertLess(len(output_ids.tolist()), len(expected_output_ids))
+            torch.manual_seed(0)
+            output_ids_kwargs = self.model.semantic.generate(
+                **input_ids,
+                do_sample=False,
+                temperature=0.9,
+                semantic_generation_config=self.semantic_generation_config,
+                min_eos_p=min_eos_p,
+            )
+        self.assertListEqual(output_ids_without_min_eos_p[0, : len(expected_output_ids)].tolist(), expected_output_ids)
+        self.assertLess(len(output_ids_kwargs[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))

         # Should be able to read min_eos_p from the semantic generation config
-        self.semantic_generation_config.min_eos_p = 0.01
+        self.semantic_generation_config.min_eos_p = min_eos_p
         with torch.no_grad():
+            torch.manual_seed(0)
             output_ids = self.model.semantic.generate(
                 **input_ids,
-                do_sample=True,
-                temperature=1.0,
+                do_sample=False,
+                temperature=0.9,
                 semantic_generation_config=self.semantic_generation_config,
             )
-        self.assertLess(len(output_ids.tolist()), len(expected_output_ids))
+
+        self.assertEqual(output_ids.shape, output_ids_kwargs.shape)
+        self.assertLess(len(output_ids[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))
+        self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)

     @slow
     def test_generate_coarse(self):
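The series as merged leaves `BarkEosPrioritizerLogitsProcessor` wired into the semantic stage of Bark generation. A closing usage sketch — the checkpoint name and prompt are assumptions for illustration, not taken from the patches:

import torch
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")
inputs = processor("Hello, my dog is cute.")

with torch.no_grad():
    # Without min_eos_p, Bark tends to keep "speaking" past the prompt;
    # with it, the semantic model stops as soon as P(eos) clears the bar.
    audio_long = model.generate(**inputs)
    audio_trimmed = model.generate(**inputs, min_eos_p=0.2)

print(audio_long.shape, audio_trimmed.shape)  # the trimmed output is typically shorter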