diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 2c4ab9c2a974..ebb9a18d559f 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -14,6 +14,7 @@ # limitations under the License. import copy +import weakref from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple import numpy as np @@ -27,7 +28,7 @@ from ..cache_utils import DynamicCache from ..pytorch_utils import isin_mps_friendly -from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor +from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor if TYPE_CHECKING: @@ -283,18 +284,21 @@ def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]: min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0) return min_new_tokens, max_new_tokens - def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool: + def _update_past_and_masks( + self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1 + ) -> bool: """Update past key values and attention masks for subsequent generation rounds.""" has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None if has_past_key_values: new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv self.assistant_kwargs["past_key_values"] = _crop_past_key_values( - self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1 + self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens ) self.assistant_kwargs = _prepare_attention_mask( self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder ) self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1]) + return has_past_key_values def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict: @@ -608,6 +612,290 @@ def _process_assistant_outputs( return new_target_ids +class AssistantToTargetTranslator: + """ + Translates token ids and logits between assistant and target model vocabularies. This class is used to handle + vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding, + as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies" + (https://www.arxiv.org/abs/2502.05202). + It maintains mappings between the two vocabularies and handles token/logit conversion. + + Args: + target_tokenizer (`PreTrainedTokenizerBase`): + The tokenizer used by the target (main) model. + assistant_tokenizer (`PreTrainedTokenizerBase`): + The tokenizer used by the assistant model. + assistant_model_device (`str`, defaults to "cpu"): + The device where the assistant model is located. Used for placing tensors. + target_vocab_size (`int`, *optional*): + The size of the target model's vocabulary. If not provided, will be inferred from the target tokenizer. + """ + + FILTER_VALUE: float = -float("Inf") # The value used to filter out unmapped tokens in the logits. + SUPPRESS_TOKEN_ID: int = -1 # The ID used to mark suppressed tokens in the mapping. 
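+
+    # Illustrative example (hypothetical vocabularies): with an assistant vocab
+    # {"hello": 0, "new": 1} and a target vocab {"hello": 7}, the mapping tensor
+    # built below is [7, SUPPRESS_TOKEN_ID]; assistant id 0 translates to target
+    # id 7, while assistant id 1 has no target counterpart and is suppressed
+    # during assistant generation via SuppressTokensLogitsProcessor.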
+ + def __init__( + self, + target_tokenizer: "PreTrainedTokenizerBase", + assistant_tokenizer: "PreTrainedTokenizerBase", + target_vocab_size: int, # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab() + assistant_model_device: str = "cpu", + ): + self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer + self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer + self._assistant_model_device: str = assistant_model_device + self.target_vocab_size: int = target_vocab_size + self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = ( + self._get_assistant_to_target_input_ids() + ) + self._suppress_input_ids: list[int] = self._get_suppress_input_ids() + self.logits_processors: Optional[LogitsProcessorList] = None + if len(self._suppress_input_ids) > 0: + # len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab + self.logits_processors = LogitsProcessorList( + [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)] + ) + + def _get_assistant_to_target_input_ids(self): + target_vocab = self._target_tokenizer.get_vocab() + assistant_vocab = self._assistant_tokenizer.get_vocab() + + space_str = " " + target_space_ids = self._target_tokenizer(space_str, add_special_tokens=False)["input_ids"] + if len(target_space_ids) > 0: + target_space_sign = self._target_tokenizer.convert_ids_to_tokens(target_space_ids)[0][0] + + assistant_space_ids = self._assistant_tokenizer(space_str, add_special_tokens=False)["input_ids"] + if len(assistant_space_ids) > 0: + assistant_space_sign = self._assistant_tokenizer.convert_ids_to_tokens(assistant_space_ids)[0][0] + + if target_space_sign != assistant_space_sign: + # If the assistant tokenizer has a different space sign than the target tokenizer, + # we need to replace the assistant space sign with the target space sign in the assistant_vocab. + assistant_vocab = { + ( + tok.replace(assistant_space_sign, target_space_sign, 1) + if tok.startswith(assistant_space_sign) + else tok + ): idx + for tok, idx in assistant_vocab.items() + } + + max_assistant_index = max(assistant_vocab.values()) + assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int) + target_to_assistant_input_ids: Dict[int, int] = {} + for tok, assistant_id in assistant_vocab.items(): + target_id = target_vocab.get(tok) + if target_id is not None: + assistant_to_target_input_ids[assistant_id] = target_id + target_to_assistant_input_ids[target_id] = assistant_id + return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids + + def _get_suppress_input_ids(self) -> list[int]: + """ + Get the input ids that are in the assistant vocab but not in the target vocab. + """ + return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0] + + def get_target_ids( + self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor + ) -> torch.LongTensor: + """ + Return the target candidate ids that correspond to the assistant candidate ids. + Note that we have already the target ids for the prompt and we only need to find the target ids for the new tokens. + Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids. 
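+
+        Example (hypothetical ids): with `target_input_ids == [[5, 9]]`, `assistant_input_ids == [[3, 7]]`
+        and `assistant_candidate_ids == [[3, 7, 4]]`, only the last assistant token is new; if assistant
+        id 4 maps to target id 12 in `_assistant_to_target_input_ids`, the method returns `[[5, 9, 12]]`.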
+ """ + + num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1] + if num_new_tokens == 0: + return target_input_ids + else: + transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]] + return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) + + def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor: + """ + Return the target logits that correspond to the assistant logits. + """ + + target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size) + target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE).to(self._assistant_model_device) + # Mask for valid indices + assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID + # Exclude invalid indices + target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] + valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]] + + target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask] + + return target_logits + + +class AssistantVocabTranslatorCache: + """ + Cache for `AssistantToTargetTranslator` instances. The instances are computed at + pre-processing time, and this cache allows us to avoid recomputing them. + """ + + _cache = weakref.WeakKeyDictionary() + + @classmethod + def get_translator( + cls, + target_tokenizer: "PreTrainedTokenizerBase", + assistant_tokenizer: "PreTrainedTokenizerBase", + target_vocab_size: int, + assistant_model_device: str = "cpu", + ) -> AssistantToTargetTranslator: + assistant_dict = cls._cache.get(target_tokenizer) + if assistant_dict is None: + assistant_dict = weakref.WeakKeyDictionary() + cls._cache[target_tokenizer] = assistant_dict + + mapping = assistant_dict.get(assistant_tokenizer) + if mapping is None: + mapping = AssistantToTargetTranslator( + target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device + ) + assistant_dict[assistant_tokenizer] = mapping + + return mapping + + @classmethod + def cleanup(cls): + """ + Clean up dead references in the cache. + This removes entries where either the target_tokenizer or assistant_tokenizer + has been garbage collected. + """ + # Remove entries from the outer cache where the target_tokenizer is no longer alive + dead_keys = [key for key in cls._cache if key is None] + for key in dead_keys: + del cls._cache[key] + + # For each assistant_dict, remove entries where assistant_tokenizer is no longer alive + for assistant_dict in cls._cache.values(): + dead_keys = [key for key in assistant_dict if key is None] + for key in dead_keys: + del assistant_dict[key] + + +class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentTokenizers): + """ + `CandidateGenerator` class to be used for Universal Speculative Decoding (USD): speculative decoding with different tokenizers + for the assistant and main models. This class generates candidates through the use of a smaller model. 
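+
+    Unlike `AssistedCandidateGeneratorDifferentTokenizers`, which bridges the two vocabularies by decoding
+    and re-encoding text, this generator maps candidate ids and logits directly through an
+    `AssistantToTargetTranslator`, so the target model can verify the candidates with sampling.
+    `_get_candidate_generator` selects it when the tokenizers differ and `generation_config.do_sample` is True.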
+ """ + + def __init__( + self, + input_ids: torch.LongTensor, + assistant_model: "PreTrainedModel", + target_tokenizer: "PreTrainedTokenizerBase", + assistant_tokenizer: "PreTrainedTokenizerBase", + generation_config: "GenerationConfig", + model_kwargs: Dict, + atm_translator: AssistantToTargetTranslator, + inputs_tensor: Optional[torch.Tensor] = None, + logits_processor: "LogitsProcessorList" = None, + ): + # Initialize translator before parent class + self._atm_translator = atm_translator + super().__init__( + input_ids, + assistant_model, + target_tokenizer, + assistant_tokenizer, + generation_config, + model_kwargs, + inputs_tensor, + logits_processor, + ) + # Track sequence lengths and previous assistant IDs + self._target_seq_len_with_candidates: int = 0 + self._prev_assistant_ids: Optional[torch.LongTensor] = None + + def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: + """ + Simplified version of get_candidates that uses the translator cache for token conversion. + """ + target_input_ids = input_ids.to(self.assistant_model.device) + assistant_input_ids, num_added_tokens = self._prepare_assistant_input_ids(target_input_ids) + min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids) + + if max_new_tokens == 0: + return input_ids, None + + self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens) + generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens) + + # Ensure scores are returned + generation_args["generation_config"].output_scores = True + generation_args["generation_config"].return_dict_in_generate = True + + # Generate and process outputs using translator + if self._atm_translator.logits_processors is not None: + generation_args["logits_processor"] = self._atm_translator.logits_processors + self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args) + + # Use translator to convert tokens and logits + target_candidate_ids = self._atm_translator.get_target_ids( + assistant_input_ids, target_input_ids, self._prev_assistant_ids + ) + self._target_seq_len_with_candidates = target_candidate_ids.shape[-1] + target_candidate_logits = self._atm_translator.get_target_logits(assistant_candidate_logits) + + return target_candidate_ids, target_candidate_logits + + def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1) -> bool: + if self._prev_assistant_ids is None: + # Prepare attention mask for the first generation. + # For subsequent generations, the attention mask is updated in super()_update_past_and_masks. + self.assistant_kwargs = _prepare_attention_mask( + self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder + ) + return super()._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens) + + def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> torch.LongTensor: + """ + Simplified token conversion that only processes new tokens. 
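+
+        On the first call, the full target prompt is decoded to text and re-encoded with the assistant
+        tokenizer. On later calls only the single newly accepted target token is converted: a direct lookup
+        in `target_to_assistant_input_ids` is tried first, falling back to decode/re-encode when the token
+        has no direct mapping, and the result is appended to the (possibly truncated) assistant ids cached
+        from the previous round.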
+ """ + # Calculate new tokens since last call + target_seq_len = target_input_ids.shape[-1] + if self._target_seq_len_with_candidates == 0: + new_token_count = target_seq_len + else: + new_token_count = 1 + target_new_ids = target_input_ids[:, -new_token_count:] + + # Convert the new tokens + assistant_new_ids = None + if self._target_seq_len_with_candidates > 0: + # we have only one new token and we can directly convert it + assistant_new_ids = self._atm_translator.target_to_assistant_input_ids.get(target_new_ids[0].item()) + if assistant_new_ids is None: + target_new_text = self.target_tokenizer.batch_decode( + target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + assistant_new_ids = self.assistant_tokenizer( + target_new_text, add_special_tokens=False, return_tensors="pt" + )["input_ids"].to(self.assistant_model.device) + else: + assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device) + + # Update or initialize assistant IDs + if self._prev_assistant_ids is None: + assistant_input_ids = assistant_new_ids + else: + tokens_to_remove = self._target_seq_len_with_candidates + 1 - target_seq_len + # If the number of new tokens is greater than zero, truncate the previous assistant IDs + if tokens_to_remove > 0: + self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove] + assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) + assistant_input_ids = assistant_input_ids.to(dtype=torch.long) + + return assistant_input_ids, len(assistant_new_ids[0]) + + class PromptLookupCandidateGenerator(CandidateGenerator): """ `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3c7445284d9a..5c7fffb117bc 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -26,6 +26,8 @@ from torch import nn from torch.nn import functional as F +from transformers.generation.candidate_generator import AssistantVocabTranslatorCache + from ..cache_utils import ( Cache, DynamicCache, @@ -56,6 +58,7 @@ CandidateGenerator, EarlyExitCandidateGenerator, PromptLookupCandidateGenerator, + UniversalSpeculativeDecodingGenerator, _crop_past_key_values, _prepare_attention_mask, _prepare_token_type_ids, @@ -858,16 +861,36 @@ def _get_candidate_generator( max_length=generation_config.max_length, ) elif different_tokenizers: - candidate_generator = AssistedCandidateGeneratorDifferentTokenizers( - input_ids=input_ids, - assistant_model=assistant_model, - generation_config=generation_config, - model_kwargs=model_kwargs, - inputs_tensor=inputs_tensor, - logits_processor=logits_processor, - target_tokenizer=target_tokenizer, - assistant_tokenizer=assistant_tokenizer, - ) + if generation_config.do_sample is True: + atm_translator = AssistantVocabTranslatorCache.get_translator( + target_tokenizer, assistant_tokenizer, self.config.vocab_size, assistant_model.device + ) + candidate_generator = UniversalSpeculativeDecodingGenerator( + input_ids=input_ids, + assistant_model=assistant_model, + generation_config=generation_config, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + logits_processor=logits_processor, + target_tokenizer=target_tokenizer, + assistant_tokenizer=assistant_tokenizer, + atm_translator=atm_translator, + ) + elif generation_config.do_sample is False: + candidate_generator = 
AssistedCandidateGeneratorDifferentTokenizers( + input_ids=input_ids, + assistant_model=assistant_model, + generation_config=generation_config, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + logits_processor=logits_processor, + target_tokenizer=target_tokenizer, + assistant_tokenizer=assistant_tokenizer, + ) + else: + raise ValueError( + f"Invalid value for `do_sample`: expected a boolean, got {type(generation_config.do_sample).__name__}" + ) else: candidate_generator = AssistedCandidateGenerator( input_ids=input_ids, @@ -4225,7 +4248,6 @@ def _assisted_decoding( # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) - candidate_input_ids = candidate_input_ids.to(self.device) if candidate_logits is not None: candidate_logits = candidate_logits.to(self.device) diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py index 03fd51324b02..38df48ab08d2 100644 --- a/tests/generation/test_candidate_generator.py +++ b/tests/generation/test_candidate_generator.py @@ -1,43 +1,325 @@ +import gc import unittest +import weakref +from unittest.mock import MagicMock -import numpy as np +import torch -from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline +from transformers.generation.candidate_generator import ( + AssistantToTargetTranslator, + AssistantVocabTranslatorCache, + UniversalSpeculativeDecodingGenerator, +) +from transformers.testing_utils import require_torch, torch_device -class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase): - def test_no_intersection(self): - prompt = np.array([[1, 2, 3]]) - prompt_plus_new_tokens = np.array([[4, 5, 6]]) - result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens) - self.assertEqual(result, (None, None, None)) +@require_torch +class TestAssistantToTargetTranslator(unittest.TestCase): + def setUp(self): + # Create mock tokenizers with predefined vocabularies + self.target_tokenizer = MagicMock() + self.assistant_tokenizer = MagicMock() - def test_complete_overlap(self): - prompt = np.array([[1, 2, 3]]) - prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]]) - discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag( - prompt, prompt_plus_new_tokens + # Define mock vocabularies for the tokenizers + self.target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3} + self.assistant_vocab = {"hello": 0, "world": 1, "foo": 2, "baz": 4} + + self.target_tokenizer.get_vocab.return_value = self.target_vocab + self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab + self.assistant_model_device = torch_device + self.target_vocab_size = 6 + + # Instantiate the class under test + self.translator = AssistantToTargetTranslator( + target_tokenizer=self.target_tokenizer, + assistant_tokenizer=self.assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, + ) + + def test_get_assistant_to_target_input_ids(self): + """Test the mapping from assistant tokens to target tokens.""" + expected_mapping = [0, 1, 2, self.translator.SUPPRESS_TOKEN_ID, self.translator.SUPPRESS_TOKEN_ID] + actual_mapping = self.translator._assistant_to_target_input_ids.tolist() + 
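+        # Assistant ids 0-2 ('hello', 'world', 'foo') share the target ids 0-2; index 3 is unused in the
+        # assistant vocab and index 4 ('baz') has no target counterpart, so both are marked with
+        # SUPPRESS_TOKEN_ID.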
self.assertEqual(actual_mapping, expected_mapping) + + def test_get_suppress_input_ids(self): + """Test the suppression of assistant input IDs not present in the target vocabulary.""" + expected_suppress_ids = [3, 4] + actual_suppress_ids = self.translator._get_suppress_input_ids().tolist() + self.assertEqual(actual_suppress_ids, expected_suppress_ids) + + def test_get_target_ids(self): + """Test the translation of assistant candidate IDs to target candidate IDs.""" + assistant_input_ids = torch.LongTensor([[0, 1, 2]]).to( + self.assistant_model_device + ) # 'hello world foo' in assistant tokenizer + target_input_ids = torch.LongTensor([[0, 1, 2]]).to( + self.assistant_model_device + ) # 'hello world foo' in target tokenizer + assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]]).to( + self.assistant_model_device + ) # 'hello world foo baz' in assistant tokenizer + + expected_target_ids = torch.LongTensor( + [[0, 1, 2, self.translator.SUPPRESS_TOKEN_ID]] + ).to( + self.assistant_model_device + ) # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.suppress_tokens_id since it does not exist in target vocab) + + actual_target_ids = self.translator.get_target_ids( + assistant_input_ids, target_input_ids, assistant_candidate_ids + ) + self.assertTrue(torch.equal(actual_target_ids, expected_target_ids)) + + def test_get_target_logits(self): + """Test the conversion of assistant logits to target logits.""" + # Assistant logits for IDs 0, 1, 2 + assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.FILTER_VALUE]]]).to( + self.assistant_model_device + ) # Shape (1, 1, 5) + + # Expected target logits (target_vocab_size = 4) + expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.FILTER_VALUE).to( + self.assistant_model_device + ) + expected_target_logits[0, 0, 0] = 0.1 # 'hello' + expected_target_logits[0, 0, 1] = 0.2 # 'world' + expected_target_logits[0, 0, 2] = 0.3 # 'foo' + # The 'bar' token in target vocab remains at -inf + + actual_target_logits = self.translator.get_target_logits(assistant_logits) + self.assertTrue(torch.equal(actual_target_logits, expected_target_logits)) + + +class MockTokenizer: + """A simple mock tokenizer class that supports weak references.""" + + def __init__(self, vocab=None): + self._vocab = vocab or {} + + def get_vocab(self): + return self._vocab + + def __call__(self, text, add_special_tokens=True): + # Mock implementation of the __call__ method + tokens = text.split() + input_ids = [self._vocab.get(token, 0) for token in tokens] + return {"input_ids": input_ids} + + +@require_torch +class TestAssistantVocabTranslatorCache(unittest.TestCase): + def setUp(self): + # Clear the cache before each test + AssistantVocabTranslatorCache._cache.clear() + # Create mock tokenizers with different vocabularies + self.target_tokenizer = MockTokenizer({"hello": 0, "world": 1}) + self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2}) + self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3}) + self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5}) + self.assistant_model_device = torch_device + self.target_vocab_size = 6 + + def test_same_instance_for_same_tokenizers(self): + """Test that the same translator is returned for the same tokenizers.""" + translator1 = AssistantVocabTranslatorCache.get_translator( + self.target_tokenizer, + self.assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, 
) - self.assertEqual(discrep_length, 0) - np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]])) - np.testing.assert_array_equal(discrep_only, np.array([[]])) + translator2 = AssistantVocabTranslatorCache.get_translator( + self.target_tokenizer, + self.assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, + ) + self.assertIs(translator1, translator2, "Translators should be cached and identical") - def test_partial_overlap(self): - prompt = np.array([[1, 2, 3]]) - prompt_plus_new_tokens = np.array([[2, 3, 4, 5]]) - discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag( - prompt, prompt_plus_new_tokens + def test_different_instances_for_different_tokenizers(self): + """Test that different tokenizers produce different translators.""" + translator1 = AssistantVocabTranslatorCache.get_translator( + self.target_tokenizer, + self.assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, + ) + translator2 = AssistantVocabTranslatorCache.get_translator( + self.other_target_tokenizer, + self.other_assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, ) - self.assertEqual(discrep_length, 0) - np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]])) - np.testing.assert_array_equal(discrep_only, np.array([[]])) + self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers") - def test_no_new_tokens(self): - prompt = np.array([[1, 2, 3]]) - prompt_plus_new_tokens = np.array([[1, 2, 3]]) - discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag( - prompt, prompt_plus_new_tokens + def test_cache_with_weakref_key(self): + """Ensure that the cache uses weak references as keys.""" + initial_cache_size = len(AssistantVocabTranslatorCache._cache) + target_tokenizer = MockTokenizer({"hello": 0}) + assistant_tokenizer = MockTokenizer({"hello": 0}) + + # Store translator in a local variable to avoid it being kept alive + translator = AssistantVocabTranslatorCache.get_translator( + target_tokenizer, + assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, ) - self.assertEqual(discrep_length, 0) - np.testing.assert_array_equal(new_tokens_only, np.array([[]])) - np.testing.assert_array_equal(discrep_only, np.array([[]])) + self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1) + + # Delete all strong references + del target_tokenizer + del assistant_tokenizer + del translator + + # Force garbage collection + gc.collect() + + # Call cleanup to remove dead entries + AssistantVocabTranslatorCache.cleanup() + + # The cache size remains increased due to strong references + self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1) + + def test_weakref_cache_cleanup(self): + """Test that the cache cleans up translators when tokenizers are garbage collected.""" + + def create_translator(): + target_tokenizer = MockTokenizer({"hello": 0}) + assistant_tokenizer = MockTokenizer({"hello": 0}) + translator = AssistantVocabTranslatorCache.get_translator( + target_tokenizer, + assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, + ) + # Create weak references before returning + refs = 
(weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer)) + # Remove strong references inside the function + del target_tokenizer + del assistant_tokenizer + del translator + return refs + + translator_ref, target_ref, assistant_ref = create_translator() + + # Force garbage collection + gc.collect() + + # Call cleanup to remove dead entries + AssistantVocabTranslatorCache.cleanup() + + # The tokenizers and translator are not garbage collected due to strong references + self.assertIsNotNone(target_ref(), "Target tokenizer should still be alive due to strong references") + self.assertIsNotNone(assistant_ref(), "Assistant tokenizer should still be alive due to strong references") + self.assertIsNotNone(translator_ref(), "Translator should still be alive due to strong references") + + +@require_torch +class TestUniversalSpeculativeDecoding(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.target_name = "hf-internal-testing/tiny-random-LlamaForCausalLM" + cls.assistant_name = "hf-internal-testing/tiny-random-PhiForCausalLM" + + def setUp(self): + self.target_tokenizer = AutoTokenizer.from_pretrained(self.target_name) + self.target_config = AutoConfig.from_pretrained(self.target_name) + self.assistant_model = AutoModelForCausalLM.from_pretrained(self.assistant_name).to(torch_device) + self.assistant_tokenizer = AutoTokenizer.from_pretrained(self.assistant_name) + + self.generation_config = GenerationConfig() + + # Ensure required tokens exist + if self.target_tokenizer.pad_token_id is None: + self.target_tokenizer.pad_token_id = self.target_tokenizer.eos_token_id + if self.target_tokenizer.bos_token_id is None: + self.target_tokenizer.bos_token_id = self.target_tokenizer.eos_token_id + if self.assistant_tokenizer.pad_token_id is None: + self.assistant_tokenizer.pad_token_id = self.assistant_tokenizer.eos_token_id + if self.target_tokenizer.bos_token_id is None: + self.assistant_tokenizer.bos_token_id = self.assistant_tokenizer.eos_token_id + + self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device) + self.model_kwargs = { + "attention_mask": torch.ones_like(self.input_ids).to(torch_device), + } + + atm_translator = AssistantVocabTranslatorCache.get_translator( + self.target_tokenizer, self.assistant_tokenizer, self.target_config.vocab_size, torch_device + ) + self.generator = UniversalSpeculativeDecodingGenerator( + input_ids=self.input_ids, + assistant_model=self.assistant_model, + target_tokenizer=self.target_tokenizer, + assistant_tokenizer=self.assistant_tokenizer, + generation_config=self.generation_config, + model_kwargs=self.model_kwargs, + atm_translator=atm_translator, + ) + + def test_basic_generation(self): + """Test basic speculative decoding works""" + input_text = "The quick brown fox" + input_ids = self.target_tokenizer.encode(input_text, return_tensors="pt") + self.generator.input_ids = input_ids + candidates, scores = self.generator.get_candidates(input_ids) + + self.assertIsNotNone(candidates) + self.assertIsNotNone(scores) + self.assertTrue(torch.is_tensor(candidates)) + self.assertTrue(torch.is_tensor(scores)) + + def test_mismatched_vocabularies(self): + """Test handling of mismatched vocabularies between models""" + # Create input with tokens present in main but not assistant vocab + # Find a token that is not in the assistant tokenizer but in + # the main tokenizer. 
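+        # (special tokens and "reserved_" placeholder tokens are skipped so the lookup
+        # lands on a regular text token)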
+        missing_token = next(
+            token
+            for token in self.target_tokenizer.get_vocab()
+            if token not in self.assistant_tokenizer.get_vocab()
+            and token not in self.target_tokenizer.all_special_tokens
+            and "reserved_" not in token
+        )
+        input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]])
+        self.generator.input_ids = input_ids
+        candidates, scores = self.generator.get_candidates(input_ids)
+        self.assertIsNotNone(candidates)
+
+    def test_speculation_depth(self):
+        """Test different speculation depths"""
+        input_ids = self.target_tokenizer.encode("Test text", return_tensors="pt")
+        self.generator.input_ids = input_ids
+
+        for depth in [1, 8, 17]:
+            self.generator.num_assistant_tokens = depth
+            candidates, scores = self.generator.get_candidates(input_ids)
+            self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth)
+
+    def test_device_consistency(self):
+        """Test handling of inputs on different devices"""
+        input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
+        self.generator.input_ids = input_ids
+        candidates, _ = self.generator.get_candidates(input_ids)
+        self.assertEqual(candidates.device, input_ids.device)
+
+    def test_usd_vs_vanilla_sampling(self):
+        """Test that USD matches vanilla sampling with temperature set to nearly 0"""
+        prompt = "Test text"
+
+        pipe_usd = pipeline("text-generation", model=self.target_name, assistant_model=self.assistant_name)
+        pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9)  # nearly 0 temperature
+        usd_text = pipe_usd_output[0]["generated_text"]
+
+        pipe_vanilla = pipeline(
+            "text-generation",
+            model=self.target_name,
+        )
+        pipe_vanilla_output = pipe_vanilla(prompt, max_new_tokens=5, do_sample=False)
+        vanilla_text = pipe_vanilla_output[0]["generated_text"]
+
+        # Assert that the outputs match
+        self.assertEqual(usd_text, vanilla_text)
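
For reference, a minimal usage sketch of the path this diff enables (not part of the patch): it assumes the tiny test checkpoints referenced in the tests above and relies on passing `tokenizer` and `assistant_tokenizer` to `generate`, which is what reaches the `different_tokenizers` branch of `_get_candidate_generator`; with `do_sample=True` that branch now builds a `UniversalSpeculativeDecodingGenerator`.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoints are the tiny test models referenced in the tests above (illustrative only).
target = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
target_tok = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-PhiForCausalLM")
assistant_tok = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-PhiForCausalLM")

inputs = target_tok("The quick brown fox", return_tensors="pt")

# With do_sample=True and different tokenizers, candidates come from
# UniversalSpeculativeDecodingGenerator; with do_sample=False the existing
# AssistedCandidateGeneratorDifferentTokenizers is used instead.
out = target.generate(
    **inputs,
    assistant_model=assistant,
    tokenizer=target_tok,
    assistant_tokenizer=assistant_tok,
    do_sample=True,
    max_new_tokens=20,
)
print(target_tok.decode(out[0], skip_special_tokens=True))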