Merged
1 change: 1 addition & 0 deletions src/transformers/configuration_utils.py
@@ -280,6 +280,7 @@ def __init__(self, **kwargs):
         self.temperature = kwargs.pop("temperature", 1.0)
         self.top_k = kwargs.pop("top_k", 50)
         self.top_p = kwargs.pop("top_p", 1.0)
+        self.typical_p = kwargs.pop("typical_p", 1.0)
         self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
         self.length_penalty = kwargs.pop("length_penalty", 1.0)
         self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
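As a usage note: like the other sampling defaults in this block, the new `typical_p` default can be overridden through config kwargs. A hypothetical snippet (not part of the diff; `GPT2Config` is just an illustrative config class):

from transformers import GPT2Config

# typical_p defaults to 1.0, which leaves typical sampling disabled
config = GPT2Config()
assert config.typical_p == 1.0

# any value < 1.0 activates TypicalLogitsWarper during sampling
config = GPT2Config(typical_p=0.95)
assert config.typical_p == 0.95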
33 changes: 33 additions & 0 deletions src/transformers/generation_logits_process.py
@@ -239,6 +239,39 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         return scores


+class TypicalLogitsWarper(LogitsWarper):
+    def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+
+        self.filter_value = filter_value
+        self.mass = mass
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+        # calculate the entropy of the softmaxed scores
+        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+        p = torch.exp(normalized)
+        ent = -(normalized * p).nansum(-1, keepdim=True)
+
+        # shift the negative log-probabilities by the entropy and sort ascending (most "typical" tokens first)
+        shifted_scores = torch.abs((-normalized) - ent)
+        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+        sorted_logits = scores.gather(-1, sorted_indices)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+        # remove tokens with cumulative mass above the threshold
+        last_ind = (cumulative_probs < self.mass).sum(dim=1)
+        last_ind[last_ind < 0] = 0
+        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+        if self.min_tokens_to_keep > 1:
+            # keep at least min_tokens_to_keep tokens
+            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores
+
+
 def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
     generated_ngrams = [{} for _ in range(num_hypos)]
     for idx in range(num_hypos):
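To make the warper's behavior concrete, here is a minimal sketch of it run in isolation (the tensor values are illustrative assumptions; the import path matches this PR's module layout):

import torch

from transformers.generation_logits_process import TypicalLogitsWarper

# toy next-token logits for one sequence: a sharply peaked distribution
scores = torch.log(torch.tensor([[0.97, 0.01, 0.01, 0.01]]))

# keep the tokens closest to the entropy until 50% of the mass is covered
warper = TypicalLogitsWarper(mass=0.5)
filtered = warper(input_ids=None, scores=scores)

# filtered-out tokens are set to filter_value (-inf), so softmax assigns them 0
print(torch.softmax(filtered, dim=-1))  # ~[[1.0, 0.0, 0.0, 0.0]]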
16 changes: 13 additions & 3 deletions src/transformers/generation_utils.py
@@ -40,6 +40,7 @@
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,
+    TypicalLogitsWarper,
 )
 from .generation_stopping_criteria import (
     MaxLengthCriteria,
@@ -619,7 +620,12 @@ def _reorder_cache(self, past, beam_idx):
         )

     def _get_logits_warper(
-        self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
+        self,
+        top_k: int = None,
+        top_p: float = None,
+        typical_p: float = None,
+        temperature: float = None,
+        num_beams: int = None,
     ) -> LogitsProcessorList:
         """
         This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@@ -629,6 +635,7 @@ def _get_logits_warper(
         # init warp parameters
         top_k = top_k if top_k is not None else self.config.top_k
         top_p = top_p if top_p is not None else self.config.top_p
+        typical_p = typical_p if typical_p is not None else self.config.typical_p
         temperature = temperature if temperature is not None else self.config.temperature
         # instantiate warpers list
         warpers = LogitsProcessorList()
Expand All @@ -641,6 +648,8 @@ def _get_logits_warper(
             warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
         if top_p is not None and top_p < 1.0:
             warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
+        if typical_p is not None and typical_p < 1.0:

[Contributor comment]: typical_p should not be equal to 1.0 IMO cc @harubaru

+            warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
         return warpers

     def _get_logits_processor(
@@ -810,6 +819,7 @@ def generate(
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        typical_p: Optional[float] = None,
         repetition_penalty: Optional[float] = None,
         bad_words_ids: Optional[Iterable[int]] = None,
         bos_token_id: Optional[int] = None,
@@ -1188,7 +1198,7 @@ def generate(
         elif is_sample_gen_mode:
             # 10. prepare logits warper
             logits_warper = self._get_logits_warper(
-                top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
+                top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
             )

             # 11. expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1250,7 +1260,7 @@
         elif is_beam_sample_gen_mode:
             # 10. prepare logits warper
             logits_warper = self._get_logits_warper(
-                top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
+                top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
             )

             if stopping_criteria.max_length is None:
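With the parameter wired through `generate`, typical sampling can be requested directly at call time. A hypothetical usage sketch (model checkpoint and prompt are illustrative assumptions):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The meaning of life is", return_tensors="pt")

# typical_p < 1.0 activates TypicalLogitsWarper; do_sample=True selects sampling mode
output = model.generate(**inputs, do_sample=True, typical_p=0.2, max_length=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))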
1 change: 1 addition & 0 deletions tests/test_configuration_common.py
@@ -51,6 +51,7 @@
"temperature": 2.0,
"top_k": 10,
"top_p": 0.7,
"typical_p": 0.2,
"repetition_penalty": 0.8,
"length_penalty": 0.8,
"no_repeat_ngram_size": 5,
46 changes: 46 additions & 0 deletions tests/test_generation_logits_process.py
@@ -41,6 +41,7 @@
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,
+    TypicalLogitsWarper,
 )


@@ -191,6 +192,51 @@ def test_top_p_dist_warper(self):
         # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
         self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])

+    def test_typical_dist_warper(self):

[Contributor comment]: Very nice test!

+        input_ids = None
+        vocab_size = 10
+        batch_size = 2
+
+        # create a distribution and take its log (inverse of the softmax applied inside TypicalLogitsWarper)
+        dist = torch.log(
+            torch.tensor([[0.97, 0.01, 0.01, 0.01], [0.4, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float)
+        )
+
+        typical_warp = TypicalLogitsWarper(0.5)
+        filtered_dist = torch.exp(typical_warp(input_ids, dist))
+
+        # dist should be filtered to keep the tokens closest to the entropy until mass >= 0.5 is covered
+        # exp(-inf) => 0
+        EXPECTED_FILTERED_DIST = torch.tensor(
+            [[0.97, 0.0, 0.0, 0.0], [0.0, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float
+        )
+        self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
+
+        # check special cases
+        length = 5
+
+        logits = self._get_uniform_logits(batch_size=batch_size, length=length)
+        typical_warp_safety_check = TypicalLogitsWarper(mass=0.5, filter_value=0.0, min_tokens_to_keep=3)
+
+        scores = typical_warp_safety_check(input_ids, logits)
+        # uniform dist is not changed
+        self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])
+
+        # check edge cases with negative and extreme logits
+        ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
+            batch_size, 1
+        ) - (vocab_size // 2)
+
+        # make ramp_logits more extreme
+        ramp_logits[1] = ramp_logits[1] * 100.0
+
+        # make sure at least 2 tokens are kept
+        typical_warp = TypicalLogitsWarper(0.7, min_tokens_to_keep=2, filter_value=0.0)
+        filtered_dist = typical_warp(input_ids, ramp_logits)
+
+        # first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
+        self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
+
     def test_no_repeat_ngram_dist_processor(self):
         vocab_size = 3
         batch_size = 2
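The expected output of the first assertion can be verified by hand. The sketch below (a standalone illustration, mirroring the second row of the test distribution) shows why the 0.4 token is the one filtered out: its information content is farther from the entropy than that of the 0.2 tokens.

import torch

p = torch.tensor([0.4, 0.2, 0.2, 0.2])
log_p = torch.log(p)

# entropy H of the distribution
ent = -(p * log_p).sum()           # ~1.3322

# distance of each token's information content (-log p) from H
shifted = torch.abs(-log_p - ent)  # ~[0.4159, 0.2772, 0.2772, 0.2772]

# with mass=0.5, the three 0.2 tokens (closest to H) cover the required mass
# first, so the 0.4 token is masked, matching EXPECTED_FILTERED_DIST above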