Merged
29 commits
c621e66 add early stopping logits processor (isaac-chung, Oct 8, 2023)
9027782 black formmated (isaac-chung, Oct 8, 2023)
a93b03c indent (isaac-chung, Oct 8, 2023)
1de69b8 follow method signature (isaac-chung, Oct 8, 2023)
0937f06 actual logic (isaac-chung, Oct 9, 2023)
05d1647 check for None (isaac-chung, Oct 9, 2023)
88571d1 address comments on docstrings and method signature (isaac-chung, Oct 9, 2023)
528b967 add unit test under `LogitsProcessorTest` wip (isaac-chung, Oct 9, 2023)
dc33e36 unit test passing (isaac-chung, Oct 10, 2023)
ce21bc5 black formatted (isaac-chung, Oct 10, 2023)
45b3065 condition per sample (isaac-chung, Oct 10, 2023)
fa5d251 add to BarkModelIntegrationTests (isaac-chung, Oct 10, 2023)
b08346d wip BarkSemanticModelTest (isaac-chung, Oct 10, 2023)
81f90ce rename and add to kwargs handling (isaac-chung, Oct 11, 2023)
f9c1c84 not add to BarkSemanticModelTest (isaac-chung, Oct 12, 2023)
a9645bb correct logic and assert last outputs tokens different in test (isaac-chung, Oct 13, 2023)
4c81b7d doc-builder style (isaac-chung, Oct 13, 2023)
ff7343b read from kwargs as well (isaac-chung, Oct 19, 2023)
9e90232 assert len of with less than that of without (isaac-chung, Oct 19, 2023)
a37aac3 ruff (isaac-chung, Oct 19, 2023)
b0200b8 add back seed and test case (isaac-chung, Oct 19, 2023)
7c743c6 add original impl default suggestion (isaac-chung, Oct 19, 2023)
28c12e9 doc-builder (isaac-chung, Oct 20, 2023)
542f41e rename and use softmax (isaac-chung, Oct 24, 2023)
e16309d switch back to LogitsProcessor and update docs wording (isaac-chung, Oct 25, 2023)
0c75c1c camelCase and spelling and saving compute (isaac-chung, Oct 25, 2023)
d88acd8 assert strictly less than (isaac-chung, Oct 25, 2023)
eb9b2f0 assert less than (isaac-chung, Oct 26, 2023)
9e83547 expand test_generate_semantic_early_stop instead (isaac-chung, Oct 27, 2023)
32 changes: 32 additions & 0 deletions src/transformers/generation/logits_process.py
@@ -1702,3 +1702,35 @@ def __call__(self, input_ids, scores):
unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
return out


class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`.

Args:
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
min_eos_p (`float`, *optional*):
Minimum probability threshold for the end-of-speech (EOS) token. When the EOS probability exceeds
this value, all non-EOS logits are set to `-inf` so that generation stops early.
"""

def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id
if min_eos_p is not None and min_eos_p <= 0:
raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
self.min_eos_p = min_eos_p

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if self.min_eos_p:
probs = torch.nn.functional.softmax(scores.float(), dim=-1)
# create scores full of -inf except for the eos_token_id
early_stop_scores = torch.ones_like(scores) * -float("inf")
early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]

            do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
            # collapse the per-EOS-token checks into a single decision per sample so
            # the mask broadcasts over the vocabulary even when several EOS ids are set
            do_early_stop = torch.any(do_early_stop, dim=1, keepdim=True)
            scores = torch.where(do_early_stop, early_stop_scores, scores)

return scores
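To make the behavior concrete, here is a minimal sketch of the processor in isolation (the token ids, logit values, and threshold are illustrative, not taken from the PR):

```python
import torch

from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor

# Batch of 2 samples over a 4-token vocabulary, with EOS id 3.
processor = BarkEosPrioritizerLogitsProcessor(eos_token_id=3, min_eos_p=0.2)

scores = torch.tensor(
    [
        [1.0, 1.0, 1.0, 1.0],  # uniform: p(EOS) = 0.25 > 0.2 -> early stop
        [4.0, 1.0, 1.0, 1.0],  # peaked elsewhere: p(EOS) ~ 0.04 < 0.2 -> unchanged
    ]
)
out = processor(input_ids=None, scores=scores)

print(out[0])  # tensor([-inf, -inf, -inf, 1.])  only EOS remains sampleable
print(out[1])  # tensor([4., 1., 1., 1.])        left untouched
```

The per-sample `torch.where` is what makes this usable in batched generation: one sequence can be forced to stop while the others keep going.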
6 changes: 6 additions & 0 deletions src/transformers/models/bark/generation_configuration_bark.py
@@ -44,6 +44,7 @@ def __init__(
semantic_vocab_size=10_000,
max_input_semantic_length=256,
semantic_rate_hz=49.9,
min_eos_p=None,
**kwargs,
):
"""Class that holds a generation configuration for [`BarkSemanticModel`].
@@ -86,6 +87,10 @@ def __init__(
Max length of semantic input vector.
semantic_rate_hz (`float`, *optional*, defaults to 49.9):
Semantic rate in Hertz.
min_eos_p (`float`, *optional*):
Minimum threshold of the probability of the EOS token for it to be sampled. This is an early stopping
strategy to mitigate potential unwanted generations at the end of a prompt. The original implementation
suggests a default value of 0.2.
"""
super().__init__(
temperature=temperature,
@@ -107,6 +112,7 @@
self.semantic_vocab_size = semantic_vocab_size
self.max_input_semantic_length = max_input_semantic_length
self.semantic_rate_hz = semantic_rate_hz
self.min_eos_p = min_eos_p


class BarkCoarseGenerationConfig(GenerationConfig):
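A quick sketch of setting the new field up front instead of passing it at call time (assuming, from the file above, that the semantic config class is named `BarkSemanticGenerationConfig`; 0.2 is the default suggested by the original implementation per the new docstring):

```python
from transformers.models.bark.generation_configuration_bark import (
    BarkSemanticGenerationConfig,
)

# The original Bark implementation suggests 0.2 as a default threshold.
semantic_generation_config = BarkSemanticGenerationConfig(min_eos_p=0.2)
print(semantic_generation_config.min_eos_p)  # 0.2
```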
16 changes: 13 additions & 3 deletions src/transformers/models/bark/modeling_bark.py
@@ -21,7 +21,11 @@
from torch import nn
from torch.nn import functional as F

from ...generation.logits_process import AlternatingCodebooksLogitsProcessor, SuppressTokensLogitsProcessor
from ...generation.logits_process import (
AlternatingCodebooksLogitsProcessor,
BarkEosPrioritizerLogitsProcessor,
SuppressTokensLogitsProcessor,
)
from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
from ...modeling_utils import PreTrainedModel, get_parameter_device
from ...utils import (
@@ -794,12 +798,17 @@ def generate(

suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress)

min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p)
early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor(
eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p
)

# pass input_ids in order to stay consistent with the transformers generate method even though it is not used
# (except to get the input seq_len - that's why we keep the first 257 tokens)
semantic_output = super().generate(
torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int).to(self.device),
input_embeds=input_embeds,
logits_processor=[suppress_tokens_logits_processor],
logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor],
generation_config=semantic_generation_config,
**kwargs,
) # size: 10048
@@ -1555,7 +1564,8 @@ def generate(

kwargs_semantic = {
# if "attention_mask" is set, it should not be passed to CoarseModel and FineModel
"attention_mask": kwargs.pop("attention_mask", None)
"attention_mask": kwargs.pop("attention_mask", None),
"min_eos_p": kwargs.pop("min_eos_p", None),
}
kwargs_coarse = {}
kwargs_fine = {}
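End to end, the new kwarg would be used roughly as follows (a hedged sketch: the checkpoint and threshold are illustrative, and per the `kwargs_semantic` routing above, `min_eos_p` reaches only the semantic sub-model):

```python
import torch
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")

inputs = processor("Hey, how are you doing today?", return_tensors="pt")

# min_eos_p is popped from kwargs and forwarded to the semantic model only;
# larger values make semantic generation stop earlier.
with torch.no_grad():
    audio_array = model.generate(**inputs, min_eos_p=0.2)
```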
17 changes: 17 additions & 0 deletions tests/generation/test_logits_process.py
@@ -53,6 +53,7 @@
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor


@require_torch
@@ -800,3 +801,19 @@ def lsm(x):
self.assertAlmostEqual(out[0].item(), res[0].item())
self.assertAlmostEqual(out[1].item(), res[1].item())
self.assertAlmostEqual(out[2].item(), res[2].item())

def test_early_stop_processor(self):
input_ids = None
eos_token_id = 2
min_eos_p = 0.1  # some small float

scores = self._get_uniform_logits(2, 4)
scores[0][eos_token_id] = -6  # low enough that p(EOS) falls below min_eos_p

esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
actual_scores = esp(input_ids, scores)
expected_scores_list = [
scores[0].tolist(),
[float("-inf"), float("-inf"), scores[0][0], float("-inf")],
]
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
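The arithmetic behind the two expected rows checks out by hand; a standalone sketch (assuming `_get_uniform_logits` returns identical logits across the vocabulary, so every uniform entry equals the retained EOS score):

```python
import torch

scores = torch.full((2, 4), 0.25)  # stand-in for uniform logits
scores[0][2] = -6.0                # row 0: suppress the EOS logit

probs = torch.softmax(scores, dim=-1)
print(probs[0][2].item())  # ~0.0006 < 0.1 -> row 0 passes through unchanged
print(probs[1][2].item())  # 0.25 > 0.1    -> row 1 keeps only the EOS score
```

Row 1 of `expected_scores_list` reuses `scores[0][0]` simply because all uniform entries share the same value as the retained EOS logit.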
66 changes: 57 additions & 9 deletions tests/models/bark/test_modeling_bark.py
@@ -917,7 +917,51 @@ def test_generate_semantic(self):
temperature=1.0,
semantic_generation_config=self.semantic_generation_config,
)
self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)

@slow
def test_generate_semantic_early_stop(self):
input_ids = self.inputs
min_eos_p = 0.01

# fmt: off
# check first ids
expected_output_ids = [7363, 321, 41, 1461, 6915, 952, 326, 41, 41, 927,]
# fmt: on

# Should be able to read min_eos_p from kwargs
with torch.no_grad():
torch.manual_seed(0)
output_ids_without_min_eos_p = self.model.semantic.generate(
**input_ids,
do_sample=False,
temperature=0.9,
semantic_generation_config=self.semantic_generation_config,
)
torch.manual_seed(0)
output_ids_kwargs = self.model.semantic.generate(
**input_ids,
do_sample=False,
temperature=0.9,
semantic_generation_config=self.semantic_generation_config,
min_eos_p=min_eos_p,
)
self.assertListEqual(output_ids_without_min_eos_p[0, : len(expected_output_ids)].tolist(), expected_output_ids)
self.assertLess(len(output_ids_kwargs[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))

# Should be able to read min_eos_p from the semantic generation config
self.semantic_generation_config.min_eos_p = min_eos_p
with torch.no_grad():
torch.manual_seed(0)
output_ids = self.model.semantic.generate(
**input_ids,
do_sample=False,
temperature=0.9,
semantic_generation_config=self.semantic_generation_config,
)

self.assertEqual(output_ids.shape, output_ids_kwargs.shape)
self.assertLess(len(output_ids[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))
self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)

@slow
@@ -1022,26 +1066,30 @@ def test_generate_end_to_end_with_sub_models_args(self):
input_ids = self.inputs

with torch.no_grad():
torch.manual_seed(0)
self.model.generate(
**input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
)
self.model.generate(
output_ids_without_min_eos_p = self.model.generate(
**input_ids,
do_sample=False,
temperature=1.0,
do_sample=True,
temperature=0.9,
coarse_do_sample=True,
coarse_temperature=0.7,
fine_temperature=0.3,
)
self.model.generate(

output_ids_with_min_eos_p = self.model.generate(
**input_ids,
do_sample=True,
temperature=0.6,
penalty_alpha=0.6,
semantic_temperature=0.9,
coarse_temperature=0.2,
fine_temperature=0.1,
temperature=0.9,
coarse_temperature=0.7,
fine_temperature=0.3,
min_eos_p=0.1,
)
self.assertLess(
len(output_ids_with_min_eos_p[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist())
)

@require_torch_gpu
@slow