diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 670fb7856020..53910f3e8aaf 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -280,6 +280,7 @@ def __init__(self, **kwargs):
         self.temperature = kwargs.pop("temperature", 1.0)
         self.top_k = kwargs.pop("top_k", 50)
         self.top_p = kwargs.pop("top_p", 1.0)
+        self.typical_p = kwargs.pop("typical_p", 1.0)
         self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
         self.length_penalty = kwargs.pop("length_penalty", 1.0)
         self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py
index 3e8a2f3cff9e..ad79273502e9 100644
--- a/src/transformers/generation_logits_process.py
+++ b/src/transformers/generation_logits_process.py
@@ -239,6 +239,39 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
         return scores
 
 
+class TypicalLogitsWarper(LogitsWarper):
+    def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+
+        self.filter_value = filter_value
+        self.mass = mass
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+        # calculate entropy
+        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+        p = torch.exp(normalized)
+        ent = -(normalized * p).nansum(-1, keepdim=True)
+
+        # shift and sort
+        shifted_scores = torch.abs((-normalized) - ent)
+        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+        sorted_logits = scores.gather(-1, sorted_indices)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+        # Remove tokens with cumulative mass above the threshold
+        last_ind = (cumulative_probs < self.mass).sum(dim=1)
+        last_ind[last_ind < 0] = 0
+        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+        if self.min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores
+
+
 def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
     generated_ngrams = [{} for _ in range(num_hypos)]
     for idx in range(num_hypos):
diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index e822609ecc0c..c225436d1b7b 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -40,6 +40,7 @@
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,
+    TypicalLogitsWarper,
 )
 from .generation_stopping_criteria import (
     MaxLengthCriteria,
@@ -619,7 +620,12 @@ def _reorder_cache(self, past, beam_idx):
         )
 
     def _get_logits_warper(
-        self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
+        self,
+        top_k: int = None,
+        top_p: float = None,
+        typical_p: float = None,
+        temperature: float = None,
+        num_beams: int = None,
     ) -> LogitsProcessorList:
         """
         This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@@ -629,6 +635,7 @@ def _get_logits_warper(
         # init warp parameters
         top_k = top_k if top_k is not None else self.config.top_k
         top_p = top_p if top_p is not None else self.config.top_p
+        typical_p = typical_p if typical_p is not None else self.config.typical_p
         temperature = temperature if temperature is not None else self.config.temperature
         # instantiate warpers list
         warpers = LogitsProcessorList()
@@ -641,6 +648,8 @@ def _get_logits_warper(
             warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
         if top_p is not None and top_p < 1.0:
             warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
+        if typical_p is not None and typical_p < 1.0:
+            warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
         return warpers
 
     def _get_logits_processor(
@@ -810,6 +819,7 @@ def generate(
         temperature: Optional[float] = None,
        top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        typical_p: Optional[float] = None,
         repetition_penalty: Optional[float] = None,
         bad_words_ids: Optional[Iterable[int]] = None,
         bos_token_id: Optional[int] = None,
@@ -1188,7 +1198,7 @@ def generate(
         elif is_sample_gen_mode:
             # 10. prepare logits warper
             logits_warper = self._get_logits_warper(
-                top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
+                top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
             )
 
             # 11. expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1250,7 +1260,7 @@ def generate(
         elif is_beam_sample_gen_mode:
             # 10. prepare logits warper
             logits_warper = self._get_logits_warper(
-                top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
+                top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
             )
 
             if stopping_criteria.max_length is None:
diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py
index 7a84d84c7976..a94d4f1d3f6f 100644
--- a/tests/test_configuration_common.py
+++ b/tests/test_configuration_common.py
@@ -51,6 +51,7 @@
     "temperature": 2.0,
     "top_k": 10,
     "top_p": 0.7,
+    "typical_p": 0.2,
     "repetition_penalty": 0.8,
     "length_penalty": 0.8,
     "no_repeat_ngram_size": 5,
diff --git a/tests/test_generation_logits_process.py b/tests/test_generation_logits_process.py
index e07fd3066e2e..7b6c78b286ea 100644
--- a/tests/test_generation_logits_process.py
+++ b/tests/test_generation_logits_process.py
@@ -41,6 +41,7 @@
         TemperatureLogitsWarper,
         TopKLogitsWarper,
         TopPLogitsWarper,
+        TypicalLogitsWarper,
     )
 
 
@@ -191,6 +192,51 @@ def test_top_p_dist_warper(self):
         # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
         self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
 
+    def test_typical_dist_warper(self):
+        input_ids = None
+        vocab_size = 10
+        batch_size = 2
+
+        # create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
+        dist = torch.log(
+            torch.tensor([[0.97, 0.01, 0.01, 0.01], [0.4, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float)
+        )
+
+        typical_warp = TypicalLogitsWarper(0.5)
+        filtered_dist = torch.exp(typical_warp(input_ids, dist))
+
+        # dist should be filtered to keep min num values so that sum is >= 0.7
+        # exp (-inf) => 0
+        EXPECTED_FILTERED_DIST = torch.tensor(
+            [[0.97, 0.0, 0.0, 0.0], [0.0, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float
+        )
+        self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
+
+        # check special cases
+        length = 5
+
+        logits = self._get_uniform_logits(batch_size=batch_size, length=length)
+        typical_warp_safety_check = TypicalLogitsWarper(mass=0.5, filter_value=0.0, min_tokens_to_keep=3)
+
+        scores = typical_warp_safety_check(input_ids, logits)
+        # uniform dist is not changed
+        self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])
+
+        # check edge cases with negative and extreme logits
+        ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
+            batch_size, 1
+        ) - (vocab_size // 2)
+
+        # make ramp_logits more extreme
+        ramp_logits[1] = ramp_logits[1] * 100.0
+
+        # make sure at least 2 tokens are kept
+        typical_warp = TypicalLogitsWarper(0.7, min_tokens_to_keep=2, filter_value=0.0)
+        filtered_dist = typical_warp(input_ids, ramp_logits)
+
+        # first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
+        self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
+
     def test_no_repeat_ngram_dist_processor(self):
         vocab_size = 3
         batch_size = 2
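
Not part of the patch: the snippet below is a minimal usage sketch of the new `typical_p` argument once this diff is applied. The checkpoint (`gpt2`), the prompt, and the generation length are arbitrary choices for illustration. Because `TypicalLogitsWarper` is only registered in `_get_logits_warper`, `typical_p` takes effect only on the sampling paths (`do_sample=True`, with or without beams); greedy and plain beam search ignore it.

```python
# Usage sketch (assumes this patch is installed); model, prompt, and lengths are arbitrary.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The key to a good summer barbecue is", return_tensors="pt").input_ids

torch.manual_seed(0)  # reproducible sample
# typical_p < 1.0 adds TypicalLogitsWarper(mass=typical_p) to the warper list;
# lower values prune more of the vocabulary at each step.
output = model.generate(input_ids, do_sample=True, typical_p=0.2, max_length=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

As with the existing top-k and top-p warpers, `min_tokens_to_keep` is raised to 2 when `num_beams > 1`, so beam sampling always keeps at least two candidate tokens per step.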