34 changes: 22 additions & 12 deletions src/transformers/generation/candidate_generator.py
@@ -641,7 +641,7 @@ def __init__(
         self.target_vocab_size: int = len(self._target_tokenizer.get_vocab())
         self.filter_value: float = filter_value
         self.suppress_tokens_id: int = suppress_tokens_id
-        self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids()
+        self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = self._get_assistant_to_target_input_ids()
         self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
         self.logits_processors: Optional[LogitsProcessorList] = None
         if len(self._suppress_input_ids) > 0:
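The helper now returns a pair, and the two directions are deliberately different data structures: the forward direction gets applied to whole id sequences (vectorized tensor indexing), while the reverse direction is queried one token at a time. A toy sketch of that design choice, with made-up ids rather than real vocabularies:

```python
import torch

suppress_tokens_id = -1  # assumed sentinel for tokens with no target equivalent

# Forward map as a lookup tensor: index = assistant id, value = target id (or sentinel).
assistant_to_target = torch.tensor([5, 9, suppress_tokens_id, 2])
# Reverse map as a dict: target id -> assistant id, for one-at-a-time lookups.
target_to_assistant = {5: 0, 9: 1, 2: 3}

# The tensor form translates a whole assistant sequence in one indexing op:
assistant_ids = torch.tensor([0, 3, 1])
print(assistant_to_target[assistant_ids])  # tensor([5, 2, 9])
# The dict form answers a single target -> assistant query in O(1):
print(target_to_assistant.get(9))  # 1
```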
@@ -677,10 +677,13 @@ def _get_assistant_to_target_input_ids(self):

         max_assistant_index = max(assistant_vocab.values())
         assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.suppress_tokens_id, dtype=int)
-        for tok, idx in assistant_vocab.items():
-            if tok in target_vocab:
-                assistant_to_target_input_ids[idx] = target_vocab[tok]
-        return assistant_to_target_input_ids.to(self._assistant_model_device)
+        target_to_assistant_input_ids: Dict[int, int] = {}
+        for tok, assistant_id in assistant_vocab.items():
+            target_id = target_vocab.get(tok)
+            if target_id is not None:
+                assistant_to_target_input_ids[assistant_id] = target_id
+                target_to_assistant_input_ids[target_id] = assistant_id
+        return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids

     def _get_suppress_input_ids(self) -> list[int]:
         """
@@ -864,13 +867,20 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to
             new_token_count = 1
         target_new_ids = target_input_ids[:, -new_token_count:]

-        # Convert only the new tokens
-        target_new_text = self.target_tokenizer.batch_decode(
-            target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
-        )
-        assistant_new_ids = self.assistant_tokenizer(target_new_text, add_special_tokens=False, return_tensors="pt")[
-            "input_ids"
-        ].to(self.assistant_model.device)
+        # Convert the new tokens
+        assistant_new_ids = None
+        if self._target_seq_len_with_candidates > 0:
+            # we have only one new token and we can directly convert it
+            assistant_new_ids = self._atm_translator.target_to_assistant_input_ids.get(target_new_ids[0].item())
+        if assistant_new_ids is None:
+            target_new_text = self.target_tokenizer.batch_decode(
+                target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+            )
+            assistant_new_ids = self.assistant_tokenizer(
+                target_new_text, add_special_tokens=False, return_tensors="pt"
+            )["input_ids"].to(self.assistant_model.device)
+        else:
+            assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device)

         # Update or initialize assistant IDs
         if self._prev_assistant_ids is None:
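The new control flow tries an exact-id fast path before falling back to decode-and-re-encode. Note the guard: the reverse map is only consulted after the first candidate round (`self._target_seq_len_with_candidates > 0`), when exactly one new target token exists. A standalone sketch of the same shape, with a hypothetical `retokenize` callback standing in for the `batch_decode` / re-encode fallback:

```python
import torch

def translate_new_target_token(
    target_new_ids: torch.Tensor,   # shape (1, 1): the single new target token
    target_to_assistant: dict,      # reverse map built by the translator
    retokenize,                     # fallback: decode target-side, re-encode assistant-side
    device="cpu",
) -> torch.Tensor:
    # Fast path: O(1) lookup when the token exists verbatim in both vocabularies.
    assistant_id = target_to_assistant.get(target_new_ids[0].item())
    if assistant_id is None:
        # Slow path: no direct equivalent, so round-trip through text instead.
        return retokenize(target_new_ids).to(device)
    return torch.tensor([[assistant_id]], device=device)

# Usage with a stubbed fallback:
out = translate_new_target_token(
    torch.tensor([[42]]),
    {42: 7},
    retokenize=lambda t: torch.tensor([[0]]),  # stub in place of the tokenizer round-trip
)
print(out)  # tensor([[7]]) -- resolved by the fast path, no tokenizer call needed
```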