12 changes: 12 additions & 0 deletions src/transformers/__init__.py
@@ -620,6 +620,7 @@
"MllamaConfig",
"MllamaProcessor",
],
"models.mlp_speculator": ["MLPSpeculatorConfig"],
"models.mluke": [],
"models.mobilebert": [
"MobileBertConfig",
@@ -3006,6 +3007,12 @@
"MllamaVisionModel",
]
)
_import_structure["models.mlp_speculator"].extend(
[
"MLPSpeculator",
"MLPSpeculatorPreTrainedModel",
]
)
_import_structure["models.mobilebert"].extend(
[
"MobileBertForMaskedLM",
@@ -5872,6 +5879,7 @@
MllamaConfig,
MllamaProcessor,
)
from .models.mlp_speculator import MLPSpeculatorConfig
from .models.mobilebert import (
MobileBertConfig,
MobileBertTokenizer,
@@ -7969,6 +7977,10 @@
MllamaTextModel,
MllamaVisionModel,
)
from .models.mlp_speculator import (
MLPSpeculator,
MLPSpeculatorPreTrainedModel,
)
from .models.mobilebert import (
MobileBertForMaskedLM,
MobileBertForMultipleChoice,
110 changes: 107 additions & 3 deletions src/transformers/generation/candidate_generator.py
@@ -32,6 +32,7 @@


if TYPE_CHECKING:
from ..modeling_outputs import BaseModelOutput
from ..modeling_utils import PreTrainedModel
from ..tokenization_utils_base import PreTrainedTokenizerBase
from .configuration_utils import GenerationConfig
@@ -57,7 +58,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
)

def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
def update_candidate_strategy(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, **kwargs
):
"""
Updates the candidate generation strategy based on the outcomes.

@@ -219,7 +222,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
candidate_ids, candidate_logits = self._generate_candidates(generation_args)
return candidate_ids, candidate_logits

def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
def update_candidate_strategy(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, **kwargs
):
"""
Updates the candidate generation strategy based on the outcomes.

@@ -993,7 +998,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
# assisted_generation expects logits as well, but we don't have those here, so returning None
return candidate_input_ids, None

def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
def update_candidate_strategy(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, **kwargs
):
"""
Updates the candidate generation strategy based on the outcomes.

@@ -1066,6 +1073,103 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
return candidate_ids, candidate_logits


class MLPSpeculatorCandidateGenerator(CandidateGenerator):
"""
`CandidateGenerator` class to be used for assisted generation via speculative decoding with an MLPSpeculator:
https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/
This class generates candidates with a lightweight speculator trained to predict the base model's next
tokens from its last hidden state.

Args:
assistant_model (`PreTrainedModel`):
The speculator model used to generate candidate tokens.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call.
"""

def __init__(
self,
assistant_model: "PreTrainedModel",
generation_config: "GenerationConfig",
):
self.assistant_model = assistant_model
self.last_hidden_state = None
self.last_token_id = None
if not generation_config.output_hidden_states:
raise ValueError(
"Speculative decoding with MLPSpeculator requires hidden state from the base model."
"Please set generation_config.output_hidden_states to True"
)

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.

Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

Return:
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
assessed by the model, and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated with each candidate. This generator does not compute
candidate logits, so the second element is always `None`.
"""
if self.last_hidden_state is None:
return input_ids, None

output_ids = self.assistant_model.generate_suffixes(
state=self.last_hidden_state,
ind=self.last_token_id,
topk=self.assistant_model.config.top_k_tokens_per_head,
n_candidates=1,
)

# drop n_candidate dimension; only supporting 1 candidate sequence for now
# [batch_size x n_candidates x n_predict] -> [batch_size x n_predict]
output_ids_squeezed = torch.squeeze(output_ids, 1)
candidate_ids = torch.cat([input_ids, output_ids_squeezed], dim=-1)
return candidate_ids, None

def update_candidate_strategy(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
num_matches: int,
model_outputs: "BaseModelOutput" = None,
valid_tokens: torch.LongTensor = None,
**kwargs,
):
"""
Updates the candidate generation strategy based on the outcomes.

Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
using beam search, or the log softmax over the vocabulary when using beam search.
num_matches (`int`):
The number of matches between the candidate sequences and the model predictions.
model_outputs (`BaseModelOutput`, *optional*):
The current iteration's generation output from the base model, containing its hidden states.
valid_tokens (`torch.LongTensor` of shape `(batch_size, num_matches+1)`, *optional*):
Token ids of the tokens generated by the base model in the current iteration.

"""
if model_outputs is None:
return

if self.last_hidden_state is None: # first generation iteration
last_token_idx = -1 # use the hidden state from the latest token generated by the base model
else:
last_token_idx = num_matches # use the hidden state of the last valid token

last_layer_idx = -1

# [num_layers x batch_size x num_tokens x hidden_dim] -> [batch_size x 1 x hidden_dim]
self.last_hidden_state = model_outputs.hidden_states[last_layer_idx][:, last_token_idx, :].unsqueeze(1)
self.last_token_id = valid_tokens[:, -1].unsqueeze(1) # shape [batch_size x 1]


def _crop_past_key_values(model, past_key_values, max_length):
"""Crops the past key values up to a certain maximum length."""
new_past = []
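End to end, the new generator is exercised through the standard assisted-generation entry point. A minimal
sketch of the intended usage, assuming a trained speculator checkpoint whose saved config lists
`MLPSpeculatorPreTrainedModel` under `architectures` (both checkpoint names below are hypothetical):

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.mlp_speculator import MLPSpeculator

tokenizer = AutoTokenizer.from_pretrained("my-org/base-llm")                  # hypothetical checkpoint
base_model = AutoModelForCausalLM.from_pretrained("my-org/base-llm")          # hypothetical checkpoint
speculator = MLPSpeculator.from_pretrained("my-org/base-llm-mlp-speculator")  # hypothetical checkpoint

inputs = tokenizer("The capital of France is", return_tensors="pt")
# output_hidden_states=True is required: MLPSpeculatorCandidateGenerator.__init__ raises a
# ValueError otherwise, since the speculator conditions on the base model's last hidden state.
output_ids = base_model.generate(
    **inputs,
    assistant_model=speculator,
    output_hidden_states=True,
    max_new_tokens=64,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))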
11 changes: 10 additions & 1 deletion src/transformers/generation/utils.py
@@ -58,6 +58,7 @@
AssistedCandidateGeneratorDifferentTokenizers,
CandidateGenerator,
EarlyExitCandidateGenerator,
MLPSpeculatorCandidateGenerator,
PromptLookupCandidateGenerator,
UniversalSpeculativeDecodingGenerator,
_crop_past_key_values,
@@ -959,6 +960,11 @@ def _get_candidate_generator(
max_matching_ngram_size=generation_config.max_matching_ngram_size,
max_length=generation_config.max_length,
)
elif "MLPSpeculatorPreTrainedModel" in assistant_model.config.architectures:
candidate_generator = MLPSpeculatorCandidateGenerator(
assistant_model=assistant_model,
generation_config=generation_config,
)
elif different_tokenizers:
if generation_config.do_sample is True:
atm_translator = AssistantVocabTranslatorCache.get_translator(
@@ -4763,7 +4769,9 @@ def _assisted_decoding(
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)

# 5. Update the candidate generation strategy if needed
candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)
candidate_generator.update_candidate_strategy(
input_ids, new_logits, n_matches, model_outputs=outputs, valid_tokens=valid_tokens
)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
@@ -4825,6 +4833,7 @@ def _assisted_decoding(

if (
hasattr(candidate_generator, "assistant_model")
and candidate_generator.assistant_model.generation_config
and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic"
):
candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
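Note that the branch added above dispatches on the assistant's config rather than its Python type: the
assistant is only routed to `MLPSpeculatorCandidateGenerator` when its saved config advertises the
architecture. An illustrative check (the checkpoint path is hypothetical):

from transformers.models.mlp_speculator import MLPSpeculator

speculator = MLPSpeculator.from_pretrained("my-org/base-llm-mlp-speculator")  # hypothetical path
# _get_candidate_generator only takes the MLPSpeculator branch when this holds:
assert "MLPSpeculatorPreTrainedModel" in (speculator.config.architectures or [])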
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -173,6 +173,7 @@
mistral3,
mixtral,
mllama,
mlp_speculator,
mluke,
mobilebert,
mobilenet_v1,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -196,6 +196,7 @@
("mistral3", "Mistral3Config"),
("mixtral", "MixtralConfig"),
("mllama", "MllamaConfig"),
("mlp_speculator", "MLPSpeculatorConfig"),
("mobilebert", "MobileBertConfig"),
("mobilenet_v1", "MobileNetV1Config"),
("mobilenet_v2", "MobileNetV2Config"),
@@ -549,6 +550,7 @@
("mistral3", "Mistral3"),
("mixtral", "Mixtral"),
("mllama", "Mllama"),
("mlp_speculator", "MLPSpeculator"),
("mluke", "mLUKE"),
("mms", "MMS"),
("mobilebert", "MobileBERT"),
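With these two registrations, the new config resolves through the auto machinery by its `model_type`
string. A sketch, assuming a saved checkpoint (the path below is hypothetical):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("my-org/base-llm-mlp-speculator")  # hypothetical path
print(type(config).__name__)  # MLPSpeculatorConfig, resolved via model_type="mlp_speculator"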
42 changes: 42 additions & 0 deletions src/transformers/models/mlp_speculator/__init__.py
@@ -0,0 +1,42 @@
from typing import TYPE_CHECKING

from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)


_import_structure = {
"configuration_mlp_speculator": ["MLPSpeculatorConfig"],
}

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mlp_speculator"] = [
"MLPSpeculator",
"MLPSpeculatorPreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_mlp_speculator import MLPSpeculatorConfig

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mlp_speculator import (
MLPSpeculator,
MLPSpeculatorPreTrainedModel,
)

else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
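This follows the library's standard `_LazyModule` pattern: nothing is imported until a name is first
accessed, and the modeling classes are only registered when torch is available. Assuming torch is
installed, both import paths below should resolve to the same objects:

from transformers.models.mlp_speculator import MLPSpeculator, MLPSpeculatorConfig
# equivalently, via the lazy top-level package (see the src/transformers/__init__.py changes above):
from transformers import MLPSpeculator, MLPSpeculatorConfig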
54 changes: 54 additions & 0 deletions src/transformers/models/mlp_speculator/configuration_mlp_speculator.py
@@ -0,0 +1,54 @@
from typing import List, Optional

from ...configuration_utils import PretrainedConfig


class MLPSpeculatorConfig(PretrainedConfig):
model_type = "mlp_speculator"

def __init__(
self,
vocab_size: int = 32000,
emb_dim: int = 4096,
inner_dim: int = 0,
n_predict: int = 3,
top_k_tokens_per_head: Optional[List[int]] = None,
n_candidates: int = 5,
tie_weights: bool = False,
scale_input: bool = False,
**kwargs,
):
"""
Initialize an MLPSpeculatorConfig

Args:
vocab_size: int
The model vocab size.
emb_dim: int
The model embedding dimension.
inner_dim: int
The inner dimension of the model. If 0, `emb_dim` is used.
n_predict: int
The number of lookahead tokens the speculator predicts.
top_k_tokens_per_head: List[int]
Number of tokens to consider from each head when forming the candidate tree.
For each candidate branch in the tree, head n produces topk[n] additional sub-branches.
Must contain exactly `n_predict` entries. Defaults to `[5, 4, 3]`.
n_candidates: int
Number of child candidates to create per sequence.
tie_weights: bool
If True, use a single set of weights for every model head/stage after the first.
The initial projection from the base model may have a different size, so that stays separate.
scale_input: bool
If True, apply an extra layernorm to the initial state vector input.
This helps training dynamics, particularly when the base model's output has an unusual scale.
if top_k_tokens_per_head is None:
    top_k_tokens_per_head = [5, 4, 3]  # default from the original signature, kept off the signature to avoid a mutable default
if len(top_k_tokens_per_head) != n_predict:
    raise ValueError("`top_k_tokens_per_head` must have exactly `n_predict` entries")
self.vocab_size = vocab_size
self.emb_dim = emb_dim
self.inner_dim = inner_dim
self.n_predict = n_predict
self.top_k_tokens_per_head = top_k_tokens_per_head
self.n_candidates = n_candidates
self.tie_weights = tie_weights
self.scale_input = scale_input
super().__init__(**kwargs)
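For reference, a hand-built config must keep `top_k_tokens_per_head` and `n_predict` consistent, or the
validation above raises. A minimal sketch:

from transformers.models.mlp_speculator import MLPSpeculatorConfig

# four lookahead heads, each contributing progressively fewer sub-branches to the candidate tree
config = MLPSpeculatorConfig(
    vocab_size=32000,
    emb_dim=4096,
    n_predict=4,
    top_k_tokens_per_head=[5, 3, 2, 2],  # length must equal n_predict
)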