Merged
Changes from all commits
112 commits
aa7e01a
move `TestAssistedCandidateGeneratorDifferentTokenizers` into a new t…
keyboardAnt Nov 28, 2024
f6b7f20
refactor
keyboardAnt Nov 28, 2024
0ded37c
NOTHING. add space to rerun github actions tests
keyboardAnt Nov 28, 2024
d48b69b
remove it...
keyboardAnt Nov 28, 2024
b47e33a
`UniversalSpeculativeDecodingGenerator`
keyboardAnt Nov 16, 2024
8a99129
Use `UniversalSpeculativeDecodingGenerator` when `generation_config.d…
keyboardAnt Nov 16, 2024
4649bd2
assistant tokenizes only the target's new suffix
keyboardAnt Nov 16, 2024
f199c94
formatting
keyboardAnt Nov 16, 2024
19c0057
fix code
jmamou Nov 21, 2024
acf5a4b
fix code
jmamou Nov 24, 2024
3712117
formatting
keyboardAnt Nov 24, 2024
63f2f46
add `TestGenerateWithDifferentModels`
keyboardAnt Nov 24, 2024
6ac33f1
`TestGenerateWithDifferentModels` parameterize on `do_sample`
keyboardAnt Nov 24, 2024
6938311
`AssistantVocabMapping` & `AssistantVocabMappingCache`
keyboardAnt Nov 24, 2024
5a0db3b
formatting
keyboardAnt Nov 24, 2024
92f8ad3
`AssistantToTargetTranslator`: `get_target_input_ids` & `get_target_l…
keyboardAnt Nov 24, 2024
7c8708e
improve `_get_assistant_to_target_input_ids` & formatting
keyboardAnt Nov 24, 2024
880d0ae
renaming
keyboardAnt Nov 24, 2024
d9b5e74
WIP: debugging `min_new_tokens`
keyboardAnt Nov 25, 2024
25974d5
fix get_target_ids
jmamou Nov 25, 2024
b8636ab
`UniversalSpeculativeDecodingGenerator`
keyboardAnt Nov 16, 2024
1ef46b7
assistant tokenizes only the target's new suffix
keyboardAnt Nov 16, 2024
f8e94eb
formatting
keyboardAnt Nov 16, 2024
439db84
fix code
jmamou Nov 21, 2024
643901d
fix code
jmamou Nov 24, 2024
77097ff
formatting
keyboardAnt Nov 24, 2024
d08b4f0
`TestGenerateWithDifferentModels` parameterize on `do_sample`
keyboardAnt Nov 24, 2024
f242dc1
`AssistantVocabMapping` & `AssistantVocabMappingCache`
keyboardAnt Nov 24, 2024
ede1176
formatting
keyboardAnt Nov 24, 2024
511ee96
`AssistantToTargetTranslator`: `get_target_input_ids` & `get_target_l…
keyboardAnt Nov 24, 2024
5e47945
improve `_get_assistant_to_target_input_ids` & formatting
keyboardAnt Nov 24, 2024
25a4349
renaming
keyboardAnt Nov 24, 2024
95fe744
WIP: debugging `min_new_tokens`
keyboardAnt Nov 25, 2024
0ad88b2
fix get_target_ids
jmamou Nov 25, 2024
bc5fa61
fix device issue
jmamou Nov 25, 2024
41a5670
fix get_assistant_input_ids
jmamou Nov 25, 2024
44f7ba7
add `TestAssistedCandidateGeneratorDifferentTokenizers`
keyboardAnt Nov 26, 2024
57aafcc
formatting
keyboardAnt Nov 26, 2024
6f95c33
`AssistantVocabTranslatorCache` refactor & tests
keyboardAnt Nov 26, 2024
078f763
revert changes in `src/transformers/generation/logits_process.py`
keyboardAnt Nov 26, 2024
faac2fc
refactor `AssistedCandidateGenerator`
keyboardAnt Nov 26, 2024
76a2dd3
refactor `AssistedCandidateGeneratorDifferentTokenizers`
keyboardAnt Nov 26, 2024
43e96e7
formatting
keyboardAnt Nov 26, 2024
e63cb9d
refactor `UniversalSpeculativeDecodingGenerator`
keyboardAnt Nov 26, 2024
8aa6020
fix negative value for max_new_tokens
jmamou Nov 26, 2024
2169973
fix generation length target + attention_mask vs. assistant + attent
jmamou Nov 26, 2024
c6da827
fix device
jmamou Nov 26, 2024
2cf9e8e
fix negative max_new_tokens bug
jmamou Nov 27, 2024
a1c0d05
fix UAG
jmamou Nov 28, 2024
d830091
minor
jmamou Nov 28, 2024
19d0cce
formatting
keyboardAnt Nov 28, 2024
5b8217d
`AssistedCandidateGeneratorDifferentTokenizers` `lookbehind`s init
keyboardAnt Nov 28, 2024
9b0126a
resolve conflict & formatting
keyboardAnt Nov 30, 2024
578d0b3
rerun CI tests
keyboardAnt Nov 30, 2024
7db2695
remove space...
keyboardAnt Nov 30, 2024
fb69900
remove old code
keyboardAnt Dec 3, 2024
e40c775
fix candidate_input_ids device
jmamou Dec 4, 2024
b5ce873
minor
jmamou Dec 4, 2024
bfccdea
Merge pull request #4 from keyboardAnt/fix_device
keyboardAnt Dec 5, 2024
d34d7ea
formatting
keyboardAnt Dec 5, 2024
9d4d9f9
Fix prepare + apply (#7)
jmamou Dec 17, 2024
4e92e9c
Add unittests for Universal Assisted generation
gauravj14 Dec 12, 2024
3fe2d31
Merge branch 'main' into usd
jmamou Dec 18, 2024
a350b1c
fix style
jmamou Dec 18, 2024
e047adf
update tests
jmamou Dec 18, 2024
011f595
Remove unused import and fix `test_speculation_depth` test
gauravjain14 Dec 17, 2024
2652490
exclude special and reserved tokens from tokenizer for UAG
gauravjain14 Dec 18, 2024
701edbb
mv `test_universal_assisted_generation.py` to `generation/test_candid…
gauravjain14 Dec 19, 2024
7088978
Merge pull request #8 from keyboardAnt/unit_tests_usd
gauravjain14 Dec 19, 2024
3b89341
Remove unused imports and fix style using `make style` (#9)
gauravjain14 Dec 20, 2024
e43dba8
formatting
keyboardAnt Dec 21, 2024
a529795
Swap gated `meta-llama/llama-3.2` with `allenai/llama` (#10)
gauravjain14 Dec 21, 2024
9025751
Merge branch 'main' into usd
keyboardAnt Jan 6, 2025
25cd5da
Fix space sign disagreement (#12)
jmamou Jan 9, 2025
77edae2
Default values for some fields of assistant to target translator (#11)
jmamou Jan 9, 2025
a2a2882
Update candidate_generator.py (#15)
jmamou Jan 12, 2025
a556947
BUG fix in _prepare_assistant_input_ids (#14)
jmamou Jan 12, 2025
407d898
typo (`target_to_assistant_input_ids`)
keyboardAnt Jan 13, 2025
a24b193
formatting
keyboardAnt Jan 13, 2025
1afdaa3
merge upstream/main
keyboardAnt Jan 15, 2025
88f6877
Merge branch 'main' into usd
keyboardAnt Jan 16, 2025
4e3660a
Fix minor review comments (#16)
gauravjain14 Jan 27, 2025
c162c88
Fix: `token_ids.to(torch.int64)` (#18)
keyboardAnt Jan 28, 2025
d0798a0
fix dtype
keyboardAnt Jan 28, 2025
d18d090
`assistant_input_ids.to(dtype=torch.long)`
keyboardAnt Jan 28, 2025
ae2f16f
Remove unused import from test_candidate_generator.py
gauravjain14 Jan 28, 2025
02dba31
Merge branch 'main' of https://github.com/keyboardAnt/transformers in…
keyboardAnt Feb 15, 2025
49a228f
resolve pr comments (#19)
keyboardAnt Feb 15, 2025
7f76fec
formatting
keyboardAnt Feb 15, 2025
32335a5
Merge branch 'main' into usd
keyboardAnt Feb 23, 2025
1a79647
Merge branch 'main' into usd
jmamou Feb 24, 2025
78a2a2c
Merge branch 'main' into usd
jmamou Feb 25, 2025
751a099
Fix Joao's comments (#21)
jmamou Feb 25, 2025
bfb636d
Merge branch 'main' into usd
jmamou Feb 25, 2025
00e325d
Merge branch 'main' into usd
jmamou Feb 25, 2025
8a39f5b
fix style (#23)
jmamou Feb 25, 2025
64c95fe
Merge branch 'main' into usd
jmamou Feb 25, 2025
7661fc9
Move atm (#24)
jmamou Feb 25, 2025
503ece9
fix logit_processor
jmamou Feb 25, 2025
fb7187d
add atm_translator test
jmamou Feb 25, 2025
dedcf98
refactor test
jmamou Feb 25, 2025
7e3f3dc
Merge branch 'main' into usd
jmamou Feb 25, 2025
c9fc5a6
Merge branch 'main' into usd
jmamou Feb 25, 2025
94e8a31
Merge branch 'main' into usd
jmamou Feb 25, 2025
4e23470
remove threading from test
jmamou Feb 26, 2025
eae175c
Merge branch 'main' into usd
jmamou Feb 26, 2025
6784931
add require_torch in tests
jmamou Feb 26, 2025
d20f07b
Merge branch 'main' into usd
jmamou Feb 26, 2025
be79a15
move AssistantVocabTranslatorCache + add tests
jmamou Feb 26, 2025
9cb0a3a
Merge branch 'main' into usd
jmamou Feb 26, 2025
b0e7a16
ruff fix
jmamou Feb 26, 2025
683bbee
Merge branch 'main' into usd
jmamou Feb 26, 2025
294 changes: 291 additions & 3 deletions src/transformers/generation/candidate_generator.py
@@ -14,6 +14,7 @@
# limitations under the License.

import copy
import weakref
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import numpy as np
@@ -27,7 +28,7 @@

from ..cache_utils import DynamicCache
from ..pytorch_utils import isin_mps_friendly
-from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
+from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor


if TYPE_CHECKING:
@@ -283,18 +284,21 @@ def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
return min_new_tokens, max_new_tokens

-def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
+def _update_past_and_masks(
+    self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1
+) -> bool:
"""Update past key values and attention masks for subsequent generation rounds."""
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
-self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
+self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens
)
self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])

return has_past_key_values

def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
@@ -608,6 +612,290 @@ def _process_assistant_outputs(
return new_target_ids


class AssistantToTargetTranslator:
"""
Translates token ids and logits between assistant and target model vocabularies. This class is used to handle
vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding,
as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies"
(https://www.arxiv.org/abs/2502.05202).
It maintains mappings between the two vocabularies and handles token/logit conversion.

Args:
target_tokenizer (`PreTrainedTokenizerBase`):
The tokenizer used by the target (main) model.
assistant_tokenizer (`PreTrainedTokenizerBase`):
The tokenizer used by the assistant model.
assistant_model_device (`str`, defaults to "cpu"):
The device where the assistant model is located. Used for placing tensors.
target_vocab_size (`int`):
The size of the target model's vocabulary, which may differ from the length of the target tokenizer's vocab.
"""

FILTER_VALUE: float = -float("Inf") # The value used to filter out unmapped tokens in the logits.
SUPPRESS_TOKEN_ID: int = -1 # The ID used to mark suppressed tokens in the mapping.

def __init__(
self,
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
target_vocab_size: int, # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
assistant_model_device: str = "cpu",
):
self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
self._assistant_model_device: str = assistant_model_device
self.target_vocab_size: int = target_vocab_size
self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
self._get_assistant_to_target_input_ids()
)
self._suppress_input_ids: torch.Tensor = self._get_suppress_input_ids()
self.logits_processors: Optional[LogitsProcessorList] = None
if len(self._suppress_input_ids) > 0:
# len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab
self.logits_processors = LogitsProcessorList(
[SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
)

def _get_assistant_to_target_input_ids(self):
target_vocab = self._target_tokenizer.get_vocab()
assistant_vocab = self._assistant_tokenizer.get_vocab()

space_str = " "
target_space_ids = self._target_tokenizer(space_str, add_special_tokens=False)["input_ids"]
if len(target_space_ids) > 0:
target_space_sign = self._target_tokenizer.convert_ids_to_tokens(target_space_ids)[0][0]

assistant_space_ids = self._assistant_tokenizer(space_str, add_special_tokens=False)["input_ids"]
if len(assistant_space_ids) > 0:
assistant_space_sign = self._assistant_tokenizer.convert_ids_to_tokens(assistant_space_ids)[0][0]

if target_space_sign != assistant_space_sign:
# If the assistant tokenizer has a different space sign than the target tokenizer,
# we need to replace the assistant space sign with the target space sign in the assistant_vocab.
assistant_vocab = {
(
tok.replace(assistant_space_sign, target_space_sign, 1)
if tok.startswith(assistant_space_sign)
else tok
): idx
for tok, idx in assistant_vocab.items()
}

max_assistant_index = max(assistant_vocab.values())
assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int)
target_to_assistant_input_ids: Dict[int, int] = {}
for tok, assistant_id in assistant_vocab.items():
target_id = target_vocab.get(tok)
if target_id is not None:
assistant_to_target_input_ids[assistant_id] = target_id
target_to_assistant_input_ids[target_id] = assistant_id
return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids
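# Editor's note on the space-sign normalization above: GPT-2-style BPE tokenizers mark a
# leading space with "Ġ" (e.g. "Ġhello"), while SentencePiece tokenizers use "▁"
# (e.g. "▁hello"); rewriting one prefix into the other lets the same surface token match
# in both vocabularies.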

def _get_suppress_input_ids(self) -> torch.Tensor:
"""
Get the input ids that are in the assistant vocab but not in the target vocab.
"""
return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0]

def get_target_ids(
self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
) -> torch.LongTensor:
"""
Return the target candidate ids that correspond to the assistant candidate ids.
Note that we already have the target ids for the prompt, and we only need to find the target ids for the new tokens.
Moreover, the assistant ids of the original prompt do not necessarily appear in `_assistant_to_target_input_ids`.
"""

num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]
if num_new_tokens == 0:
return target_input_ids
else:
transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]]
return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)
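# Worked example (editor's sketch, made-up lengths): if `assistant_input_ids` holds 5
# tokens and `assistant_candidate_ids` holds 8, the last 3 candidate ids are looked up in
# `_assistant_to_target_input_ids` and the 3 translated ids are appended to
# `target_input_ids`; the prompt itself is never re-translated.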

def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
"""
Return the target logits that correspond to the assistant logits.
"""

target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE).to(self._assistant_model_device)
# Mask for valid indices
assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
# Exclude invalid indices
target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]

target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]

return target_logits
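# Shape sketch (editor's note): `assistant_logits` of shape (batch, seq, assistant_vocab)
# is scattered into (batch, seq, target_vocab_size); target positions with no assistant
# counterpart stay at FILTER_VALUE (-inf), so they can never be sampled.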


class AssistantVocabTranslatorCache:
"""
Cache for `AssistantToTargetTranslator` instances. The instances are computed at
pre-processing time, and this cache allows us to avoid recomputing them.
"""

_cache = weakref.WeakKeyDictionary()

@classmethod
def get_translator(
cls,
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
target_vocab_size: int,
assistant_model_device: str = "cpu",
) -> AssistantToTargetTranslator:
assistant_dict = cls._cache.get(target_tokenizer)
if assistant_dict is None:
assistant_dict = weakref.WeakKeyDictionary()
cls._cache[target_tokenizer] = assistant_dict

mapping = assistant_dict.get(assistant_tokenizer)
if mapping is None:
mapping = AssistantToTargetTranslator(
target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device
)
assistant_dict[assistant_tokenizer] = mapping

return mapping

@classmethod
def cleanup(cls):
"""
Clean up dead references in the cache.
This removes entries where either the target_tokenizer or assistant_tokenizer
has been garbage collected.
"""
# Remove entries from the outer cache where the target_tokenizer is no longer alive
dead_keys = [key for key in cls._cache if key is None]
for key in dead_keys:
del cls._cache[key]

# For each assistant_dict, remove entries where assistant_tokenizer is no longer alive
for assistant_dict in cls._cache.values():
dead_keys = [key for key in assistant_dict if key is None]
for key in dead_keys:
del assistant_dict[key]


class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentTokenizers):
"""
`CandidateGenerator` class to be used for Universal Speculative Decoding (USD): speculative decoding with different tokenizers
for the assistant and main models. This class generates candidates through the use of a smaller model.
"""

def __init__(
self,
input_ids: torch.LongTensor,
assistant_model: "PreTrainedModel",
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
generation_config: "GenerationConfig",
model_kwargs: Dict,
atm_translator: AssistantToTargetTranslator,
inputs_tensor: Optional[torch.Tensor] = None,
logits_processor: "LogitsProcessorList" = None,
):
# Initialize translator before parent class
self._atm_translator = atm_translator
super().__init__(
input_ids,
assistant_model,
target_tokenizer,
assistant_tokenizer,
generation_config,
model_kwargs,
inputs_tensor,
logits_processor,
)
# Track sequence lengths and previous assistant IDs
self._target_seq_len_with_candidates: int = 0
self._prev_assistant_ids: Optional[torch.LongTensor] = None

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Simplified version of get_candidates that uses the translator to convert tokens and logits between the assistant and target vocabularies.
"""
target_input_ids = input_ids.to(self.assistant_model.device)
assistant_input_ids, num_added_tokens = self._prepare_assistant_input_ids(target_input_ids)
min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids)

if max_new_tokens == 0:
return input_ids, None

self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)

# Ensure scores are returned
generation_args["generation_config"].output_scores = True
generation_args["generation_config"].return_dict_in_generate = True

# Generate and process outputs using translator
if self._atm_translator.logits_processors is not None:
generation_args["logits_processor"] = self._atm_translator.logits_processors
self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args)

# Use translator to convert tokens and logits
target_candidate_ids = self._atm_translator.get_target_ids(
assistant_input_ids, target_input_ids, self._prev_assistant_ids
)
self._target_seq_len_with_candidates = target_candidate_ids.shape[-1]
target_candidate_logits = self._atm_translator.get_target_logits(assistant_candidate_logits)

return target_candidate_ids, target_candidate_logits

def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1) -> bool:
if self._prev_assistant_ids is None:
# Prepare attention mask for the first generation.
# For subsequent generations, the attention mask is updated in `super()._update_past_and_masks`.
self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
)
return super()._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)

def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, int]:
"""
Simplified token conversion that only processes new tokens.
"""
# Calculate new tokens since last call
target_seq_len = target_input_ids.shape[-1]
if self._target_seq_len_with_candidates == 0:
new_token_count = target_seq_len
else:
new_token_count = 1
target_new_ids = target_input_ids[:, -new_token_count:]

# Convert the new tokens
assistant_new_ids = None
if self._target_seq_len_with_candidates > 0:
# we have only one new token and we can directly convert it
assistant_new_ids = self._atm_translator.target_to_assistant_input_ids.get(target_new_ids[0].item())
if assistant_new_ids is None:
target_new_text = self.target_tokenizer.batch_decode(
target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
assistant_new_ids = self.assistant_tokenizer(
target_new_text, add_special_tokens=False, return_tensors="pt"
)["input_ids"].to(self.assistant_model.device)
else:
assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device)

# Update or initialize assistant IDs
if self._prev_assistant_ids is None:
assistant_input_ids = assistant_new_ids
else:
tokens_to_remove = self._target_seq_len_with_candidates + 1 - target_seq_len
# If any speculated tokens were rejected by the target model, truncate them from the previous assistant IDs
if tokens_to_remove > 0:
self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
assistant_input_ids = assistant_input_ids.to(dtype=torch.long)

return assistant_input_ids, len(assistant_new_ids[0])


class PromptLookupCandidateGenerator(CandidateGenerator):
"""
`CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up