Fix prepare + apply #7
Changes from all commits
@@ -24,7 +24,6 @@
 from ..cache_utils import DynamicCache
 from ..pytorch_utils import isin_mps_friendly
 from .logits_process import (
-    LogitNormalization,
     LogitsProcessorList,
     MinLengthLogitsProcessor,
     SuppressTokensLogitsProcessor,
@@ -245,18 +244,21 @@ def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
         min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
         return min_new_tokens, max_new_tokens
 
-    def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
+    def _update_past_and_masks(
+        self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1
+    ) -> bool:
         """Update past key values and attention masks for subsequent generation rounds."""
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
             new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
             self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
-                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
+                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens
             )
             self.assistant_kwargs = _prepare_attention_mask(
                 self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
             )
             self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
 
         return has_past_key_values
 
     def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
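To make the new `num_added_tokens` argument concrete, here is a toy sketch (plain integer arithmetic, not the real KV cache) of the cropping computed above; `num_added_tokens=1` reproduces the old hard-coded `new_cache_size - 1`:

```python
# Toy sketch of the cache-cropping arithmetic in _update_past_and_masks.
# num_added_tokens=1 matches the previous hard-coded "- 1"; larger values
# crop more cache entries when several tokens were appended since the last round.
input_len, remove_from_pkv = 10, 0
new_cache_size = input_len - 1 - remove_from_pkv  # 9
for num_added_tokens in (1, 2):
    kept = new_cache_size - num_added_tokens
    print(f"num_added_tokens={num_added_tokens} -> cache cropped to {kept} entries")
```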
@@ -565,34 +567,41 @@ class AssistantToTargetTranslator:
     Translate the assistant into the target universe.
     """
 
-    def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase"):
+    def __init__(
+        self,
+        target_tokenizer: "PreTrainedTokenizerBase",
+        assistant_tokenizer: "PreTrainedTokenizerBase",
+        assistant_model_device,
+        target_vocab_size: int,
+        filter_value: float = -float("Inf"),
+        suppress_tokens_id: int = -1,
+    ):
         self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
         self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
-        self._assistant_to_target_input_ids: dict[int, int] = self._get_assistant_to_target_input_ids()
-        self.suppress_input_ids: list[int] = self._get_suppress_input_ids()
+        self._assistant_model_device = assistant_model_device
+        self.target_vocab_size: int = target_vocab_size
+        self.filter_value: float = filter_value
+        self.suppress_tokens_id: int = suppress_tokens_id
+        self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids()
         self.logits_processors: LogitsProcessorList = LogitsProcessorList(
-            [
-                SuppressTokensLogitsProcessor(self.suppress_input_ids),
-                LogitNormalization(),
-            ]
+            [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
         )
 
-    def _get_assistant_to_target_input_ids(self) -> dict[int, int]:
-        """
-        Get a mapping from assistant tokens to target tokens based on vocabularies.
-        """
+    def _get_assistant_to_target_input_ids(self):
         target_vocab = self._target_tokenizer.get_vocab()
         assistant_vocab = self._assistant_tokenizer.get_vocab()
-        return {
-            assistant_vocab[tok]: target_vocab[tok] for tok in set(target_vocab.keys()) & set(assistant_vocab.keys())
-        }
+        max_assistant_index = max(assistant_vocab.values())
+        assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.suppress_tokens_id, dtype=int)
+        for tok, idx in assistant_vocab.items():
+            if tok in target_vocab:
+                assistant_to_target_input_ids[idx] = target_vocab[tok]
+        return assistant_to_target_input_ids.to(self._assistant_model_device)
 
     def _get_suppress_input_ids(self) -> list[int]:
         """
         Get the input ids that are in the assistant vocab but not in the target vocab.
         """
-        assistant_vocab = self._assistant_tokenizer.get_vocab()
-        return list(set(assistant_vocab.values()) - set(self._assistant_to_target_input_ids.keys()))
+        return torch.where(self._assistant_to_target_input_ids == self.suppress_tokens_id)[0]
 
     def get_target_ids(
         self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
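A minimal sketch of the new lookup-table construction, using toy vocabularies rather than real tokenizers: shared tokens map assistant ids to target ids, and assistant-only tokens keep `suppress_tokens_id`, which is exactly what `_get_suppress_input_ids` later selects.

```python
import torch

# Toy vocabularies (hypothetical, for illustration only).
target_vocab = {"a": 0, "b": 1, "c": 2}
assistant_vocab = {"a": 0, "b": 1, "x": 2, "c": 3}
suppress_tokens_id = -1

# Dense assistant->target table, defaulting to suppress_tokens_id.
table = torch.full((max(assistant_vocab.values()) + 1,), suppress_tokens_id, dtype=int)
for tok, idx in assistant_vocab.items():
    if tok in target_vocab:
        table[idx] = target_vocab[tok]

print(table)                                        # tensor([ 0,  1, -1,  2])
print(torch.where(table == suppress_tokens_id)[0])  # tensor([2]): "x" has no target id
```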
@@ -602,33 +611,29 @@ def get_target_ids(
         Note that we have already the target ids for the prompt and we only need to find the target ids for the new tokens.
         Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids.
         """
-        device = assistant_candidate_ids.device
-        target_candidate_ids = (
-            assistant_candidate_ids[0, -(len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]) :]
-            .cpu()
-            .apply_(lambda x: self._assistant_to_target_input_ids.get(x, x))
-            .to(device)
-        )
-        return torch.cat((target_input_ids, target_candidate_ids.unsqueeze(0)), dim=1)
+        num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]
+        if num_new_tokens == 0:
+            return target_input_ids
+        else:
+            transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]]
+            return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)
 
     def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
         """
         Return the target logits that correspond to the assistant logits.
         """
-        device = assistant_logits.device
-        target_vocab_size: int = len(self._target_tokenizer.get_vocab())
-        target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], target_vocab_size)
-        target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(device)
-        assistant_logits_supported_mask: torch.BoolTensor = assistant_logits > -float("inf")
-        assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[
-            -1
-        ]
-        target_logits_supported_indices: torch.IntTensor = (
-            assistant_logits_supported_indices.cpu()
-            .apply_(lambda x: self._assistant_to_target_input_ids[x])
-            .to(device)
-        )
-        target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask]
+        target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
+        target_logits: torch.FloatTensor = torch.full(target_shape, self.filter_value).to(self._assistant_model_device)
+        # Mask for valid indices
+        assistant_indices_mask = self._assistant_to_target_input_ids != self.suppress_tokens_id
+        # Exclude invalid indices
+        target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
+        valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
 
+        target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
 
         return target_logits
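The scatter in `get_target_logits` can be checked on tiny tensors; this hedged sketch reuses the toy table from above (assistant id 2 suppressed) with a target vocabulary of size 3:

```python
import torch

table = torch.tensor([0, 1, -1, 2])   # toy assistant->target table from above
filter_value = -float("Inf")
target_vocab_size = 3

assistant_logits = torch.tensor([[0.5, 0.2, 0.9, 0.1]])
target_logits = torch.full((1, target_vocab_size), filter_value)

valid = table != -1                    # drop suppressed assistant ids
target_logits[..., table[valid]] = assistant_logits[..., : table.shape[0]][..., valid]
print(target_logits)                   # tensor([[0.5000, 0.2000, 0.1000]])
```

Note that the assistant-only logit (0.9 at assistant id 2) is simply dropped; it has no slot in the target space.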
@@ -643,7 +648,11 @@ class AssistantVocabTranslatorCache:
 
     @classmethod
     def get_translator(
-        cls, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase"
+        cls,
+        target_tokenizer: "PreTrainedTokenizerBase",
+        assistant_tokenizer: "PreTrainedTokenizerBase",
+        assistant_model_device,
+        target_vocab_size: int,
     ) -> AssistantToTargetTranslator:
         with cls._lock:
             assistant_dict = cls._cache.get(target_tokenizer)
@@ -653,7 +662,9 @@ def get_translator(
 
             mapping = assistant_dict.get(assistant_tokenizer)
             if mapping is None:
-                mapping = AssistantToTargetTranslator(target_tokenizer, assistant_tokenizer)
+                mapping = AssistantToTargetTranslator(
+                    target_tokenizer, assistant_tokenizer, assistant_model_device, target_vocab_size
+                )
                 assistant_dict[assistant_tokenizer] = mapping
 
         return mapping
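A hedged usage sketch of the updated factory: the checkpoint names and the import path (`transformers.generation.candidate_generator`) are assumptions about this branch's layout, but the point is that repeated calls with the same tokenizer pair should return the same cached translator.

```python
import torch
from transformers import AutoTokenizer
from transformers.generation.candidate_generator import AssistantVocabTranslatorCache  # assumed path

# Placeholder checkpoints; any tokenizer pair exercises the cache.
target_tok = AutoTokenizer.from_pretrained("gpt2")
assistant_tok = AutoTokenizer.from_pretrained("distilgpt2")
device = torch.device("cpu")

t1 = AssistantVocabTranslatorCache.get_translator(target_tok, assistant_tok, device, len(target_tok))
t2 = AssistantVocabTranslatorCache.get_translator(target_tok, assistant_tok, device, len(target_tok))
assert t1 is t2  # same pair -> same cached AssistantToTargetTranslator
```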
@@ -692,11 +703,14 @@ def __init__(
         assistant_tokenizer: "PreTrainedTokenizerBase",
         generation_config: "GenerationConfig",
         model_kwargs: Dict,
+        target_vocab_size: int,
         inputs_tensor: Optional[torch.Tensor] = None,
         logits_processor: "LogitsProcessorList" = None,
     ):
         # Initialize translator before parent class
-        self._atm_translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer)
+        self._atm_translator = AssistantVocabTranslatorCache.get_translator(
+            target_tokenizer, assistant_tokenizer, assistant_model.device, target_vocab_size
+        )
         super().__init__(
             input_ids,
             assistant_model,
@@ -708,52 +722,61 @@ def __init__(
             logits_processor,
         )
         # Track sequence lengths and previous assistant IDs
-        self._prev_target_seq_len: int = 0
+        self._target_seq_len_with_candidates: int = 0
         self._prev_assistant_ids: Optional[torch.LongTensor] = None
+        self.target_vocab_size = target_vocab_size
 
     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
         Simplified version of get_candidates that uses the translator cache for token conversion.
         """
-        input_ids = input_ids.to(self.assistant_model.device)
-        target_input_ids = input_ids.clone()
-        assistant_input_ids = self._prepare_assistant_input_ids(target_input_ids)
+        target_input_ids = input_ids.to(self.assistant_model.device)
+        assistant_input_ids, num_added_tokens = self._prepare_assistant_input_ids(target_input_ids)
         min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids)
 
         if max_new_tokens == 0:
             return input_ids, None
 
-        self._update_past_and_masks(assistant_input_ids)
+        self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
         generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
         self.assistant_kwargs.pop("attention_mask", None)
 
-        # Ensure scores are returned
-        generation_args["generation_config"].output_scores = True
-        generation_args["generation_config"].return_dict_in_generate = True
-
-        # Generate and process outputs using translator
-        assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
-        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
-
-        candidate_logits = torch.stack(assistant_output.scores, dim=1)
+        generation_args["logits_processor"] = self._atm_translator.logits_processors
+        self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args)
 
         # Use translator to convert tokens and logits
-        candidate_ids = assistant_output.sequences
-        candidate_logits = self._atm_translator.logits_processors(input_ids=candidate_ids, scores=candidate_logits)
-        target_ids = self._atm_translator.get_target_ids(assistant_input_ids, target_input_ids, candidate_ids)
-        target_logits = self._atm_translator.get_target_logits(candidate_logits)
+        target_candidate_ids = self._atm_translator.get_target_ids(
+            assistant_input_ids, target_input_ids, self._prev_assistant_ids
+        )
+        self._target_seq_len_with_candidates = target_candidate_ids.shape[-1]
+        target_candidate_logits = self._atm_translator.get_target_logits(assistant_candidate_logits)
 
-        return target_ids, target_logits
+        return target_candidate_ids, target_candidate_logits
+
+    def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1) -> bool:
+        if self._prev_assistant_ids is None:
+            # Prepare attention mask for the first generation.
+            # For subsequent generations, the attention mask is updated in super()_update_past_and_masks.
+            self.assistant_kwargs = _prepare_attention_mask(
+                self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
+            )
+        return super()._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
 
     def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> torch.LongTensor:
         """
         Simplified token conversion that only processes new tokens.
         """
         # Calculate new tokens since last call
         target_seq_len = target_input_ids.shape[-1]
-        new_token_count = target_seq_len - self._prev_target_seq_len
+        if self._target_seq_len_with_candidates == 0:
+            new_token_count = target_seq_len
+        else:
+            new_token_count = 1
         target_new_ids = target_input_ids[:, -new_token_count:]
-        self._prev_target_seq_len = target_seq_len
 
         # Convert only the new tokens
         target_new_text = self.target_tokenizer.batch_decode(
@@ -765,11 +788,16 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> torch.LongTensor:
 
         # Update or initialize assistant IDs
         if self._prev_assistant_ids is None:
-            self._prev_assistant_ids = assistant_new_ids
+            assistant_input_ids = assistant_new_ids
         else:
-            self._prev_assistant_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
-
-        return self._prev_assistant_ids
+            tokens_to_remove = self._target_seq_len_with_candidates + 1 - target_seq_len
+            # If the number of new tokens is greater than zero, truncate the previous assistant IDs
+            if tokens_to_remove > 0:
+                self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
+            assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
+        assistant_input_ids = assistant_input_ids.to(torch.int)
Owner
According to the documentation,

Collaborator (Author)
do you mean adding before

Owner
Wdyt about ensuring we only assign

Collaborator (Author)
we get all the IDs from the tokenizer and their type is
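For reference on the dtype question in this thread, a small hedged check (the checkpoint name is a placeholder): tokenizers return `int64` ids when asked for PyTorch tensors, and `.to(torch.int)` casts them to `int32`.

```python
import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
ids = tok("hello world", return_tensors="pt")["input_ids"]
print(ids.dtype)                # torch.int64
print(ids.to(torch.int).dtype)  # torch.int32
```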
+
+        return assistant_input_ids, len(assistant_new_ids[0])
 
 
 class PromptLookupCandidateGenerator(CandidateGenerator):
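Finally, a toy end-to-end sketch of the "convert only the new tokens" idea implemented by `_prepare_assistant_input_ids` (the two GPT-2 checkpoints are placeholders that happen to share a vocabulary): decode just the new target ids to text, re-encode that text with the assistant tokenizer, and append the result to the cached assistant ids.

```python
import torch
from transformers import AutoTokenizer

target_tok = AutoTokenizer.from_pretrained("gpt2")           # placeholder
assistant_tok = AutoTokenizer.from_pretrained("distilgpt2")  # placeholder

prev_assistant_ids = assistant_tok(" The dog", return_tensors="pt")["input_ids"]
target_new_ids = target_tok(" runs", return_tensors="pt")["input_ids"]  # new tokens only

# Decode new target tokens, re-encode with the assistant tokenizer, append.
text = target_tok.batch_decode(target_new_ids, skip_special_tokens=True)
assistant_new_ids = assistant_tok(text, add_special_tokens=False, return_tensors="pt")["input_ids"]
assistant_input_ids = torch.cat([prev_assistant_ids, assistant_new_ids], dim=-1).to(torch.int)
print(assistant_input_ids)
```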