diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 82b57928f3ab..ccd8acdade6c 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -620,6 +620,7 @@ "MllamaConfig", "MllamaProcessor", ], + "models.mlp_speculator": ["MLPSpeculatorConfig"], "models.mluke": [], "models.mobilebert": [ "MobileBertConfig", @@ -3006,6 +3007,12 @@ "MllamaVisionModel", ] ) + _import_structure["models.mlp_speculator"].extend( + [ + "MLPSpeculator", + "MLPSpeculatorPreTrainedModel", + ] + ) _import_structure["models.mobilebert"].extend( [ "MobileBertForMaskedLM", @@ -5872,6 +5879,7 @@ MllamaConfig, MllamaProcessor, ) + from .models.mlp_speculator import MLPSpeculatorConfig from .models.mobilebert import ( MobileBertConfig, MobileBertTokenizer, @@ -7969,6 +7977,10 @@ MllamaTextModel, MllamaVisionModel, ) + from .models.mlp_speculator import ( + MLPSpeculator, + MLPSpeculatorPreTrainedModel, + ) from .models.mobilebert import ( MobileBertForMaskedLM, MobileBertForMultipleChoice, diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index fe57f532e687..a5eb3b6d5c0f 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: + from ..modeling_outputs import BaseModelOutput from ..modeling_utils import PreTrainedModel from ..tokenization_utils_base import PreTrainedTokenizerBase from .configuration_utils import GenerationConfig @@ -57,7 +58,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`." ) - def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + def update_candidate_strategy( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, **kwargs + ): """ Updates the candidate generation strategy based on the outcomes. @@ -219,7 +222,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, candidate_ids, candidate_logits = self._generate_candidates(generation_args) return candidate_ids, candidate_logits - def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + def update_candidate_strategy( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, **kwargs + ): """ Updates the candidate generation strategy based on the outcomes. @@ -993,7 +998,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, # assisted_generation expects logits as well, but we don't have those here, so returning None return candidate_input_ids, None - def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + def update_candidate_strategy( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, **kwargs + ): """ Updates the candidate generation strategy based on the outcomes. 
@@ -1066,6 +1073,103 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, return candidate_ids, candidate_logits +class MLPSpeculatorCandidateGenerator(CandidateGenerator): + """ + `CandidateGenerator` class to be used for assisted generation via speculative decoding with MLPSpeculator: + https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/ + This class generates candidates using a speculator trained to predict the base model's outputs. + + Args: + assistant_model (`PreTrainedModel`): + The speculator model to be used for generating candidates. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. + """ + + def __init__( + self, + assistant_model: "PreTrainedModel", + generation_config: "GenerationConfig", + ): + self.assistant_model = assistant_model + self.last_hidden_state = None + self.last_token_id = None + if not generation_config.output_hidden_states: + raise ValueError( + "Speculative decoding with MLPSpeculator requires hidden states from the base model. " + "Please set generation_config.output_hidden_states to True." + ) + + def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: + """ + Fetches the candidates to be tried for the current input. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + + Return: + `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be + assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length, + vocabulary_size)` containing the logits associated to each candidate. This generator returns `None` for the logits. + """ + if self.last_hidden_state is None: + return input_ids, None + + output_ids = self.assistant_model.generate_suffixes( + state=self.last_hidden_state, + ind=self.last_token_id, + topk=self.assistant_model.config.top_k_tokens_per_head, + n_candidates=1, + ) + + # drop n_candidate dimension; only supporting 1 candidate sequence for now + # [batch_size x n_candidates x n_predict] -> [batch_size x n_predict] + output_ids_squeezed = torch.squeeze(output_ids, 1) + candidate_ids = torch.cat([input_ids, output_ids_squeezed], dim=-1) + return candidate_ids, None + + def update_candidate_strategy( + self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + num_matches: int, + model_outputs: Optional["BaseModelOutput"] = None, + valid_tokens: Optional[torch.LongTensor] = None, + **kwargs, + ): + """ + Updates the candidate generation strategy based on the outcomes. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`): + Prediction scores of a language modeling head. These can be logits for each vocabulary when not using + beam search or log softmax for each vocabulary token when using beam search + num_matches (`int`): + The number of matches between the candidate sequences and the model predictions.
+ model_outputs (`BaseModelOutput`, *optional*): + The current iteration's generation output from the base model, containing its hidden states. + valid_tokens (`torch.LongTensor` of shape `(batch_size, num_matches+1)`, *optional*): + Token ids for the tokens generated by the model in the current iteration. + + """ + if not model_outputs: + return + + if self.last_hidden_state is None: # first generation iteration + last_token_idx = -1 # use the hidden state from the latest token generated by the base model + else: + last_token_idx = num_matches # use the hidden state of the last valid token + + last_layer_idx = -1 + + # [num_layers x batch_size x num_tokens x hidden_dim] -> [batch_size x 1 x hidden_dim] + self.last_hidden_state = model_outputs.hidden_states[last_layer_idx][:, last_token_idx, :].unsqueeze(1) + self.last_token_id = valid_tokens[:, -1].unsqueeze(1) # shape [batch_size x 1] + + def _crop_past_key_values(model, past_key_values, max_length): """Crops the past key values up to a certain maximum length.""" new_past = [] diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 232cceeedf6d..deb06d1d16b1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -58,6 +58,7 @@ AssistedCandidateGeneratorDifferentTokenizers, CandidateGenerator, EarlyExitCandidateGenerator, + MLPSpeculatorCandidateGenerator, PromptLookupCandidateGenerator, UniversalSpeculativeDecodingGenerator, _crop_past_key_values, @@ -959,6 +960,11 @@ def _get_candidate_generator( max_matching_ngram_size=generation_config.max_matching_ngram_size, max_length=generation_config.max_length, ) + elif "MLPSpeculatorPreTrainedModel" in (assistant_model.config.architectures or []): + candidate_generator = MLPSpeculatorCandidateGenerator( + assistant_model=assistant_model, + generation_config=generation_config, + ) elif different_tokenizers: if generation_config.do_sample is True: atm_translator = AssistantVocabTranslatorCache.get_translator( @@ -4763,7 +4769,9 @@ def _assisted_decoding( outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) # 5.
Update the candidate generation strategy if needed - candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) + candidate_generator.update_candidate_strategy( + input_ids, new_logits, n_matches, model_outputs=outputs, valid_tokens=valid_tokens + ) # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping model_kwargs = self._update_model_kwargs_for_generation( @@ -4825,6 +4833,7 @@ def _assisted_decoding( if ( hasattr(candidate_generator, "assistant_model") + and candidate_generator.assistant_model.generation_config and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic" ): candidate_generator.assistant_model.generation_config.num_assistant_tokens = ( diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index b2cde9f4bc57..a955f2b2f014 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -173,6 +173,7 @@ mistral3, mixtral, mllama, + mlp_speculator, mluke, mobilebert, mobilenet_v1, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 9937b55a8b0f..4fb27db684ef 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -196,6 +196,7 @@ ("mistral3", "Mistral3Config"), ("mixtral", "MixtralConfig"), ("mllama", "MllamaConfig"), + ("mlp_speculator", "MLPSpeculatorConfig"), ("mobilebert", "MobileBertConfig"), ("mobilenet_v1", "MobileNetV1Config"), ("mobilenet_v2", "MobileNetV2Config"), @@ -549,6 +550,7 @@ ("mistral3", "Mistral3"), ("mixtral", "Mixtral"), ("mllama", "Mllama"), + ("mlp_speculator", "MLPSpeculator"), ("mluke", "mLUKE"), ("mms", "MMS"), ("mobilebert", "MobileBERT"), diff --git a/src/transformers/models/mlp_speculator/__init__.py b/src/transformers/models/mlp_speculator/__init__.py new file mode 100644 index 000000000000..2145e5e5cf00 --- /dev/null +++ b/src/transformers/models/mlp_speculator/__init__.py @@ -0,0 +1,42 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_mlp_speculator": ["MLPSpeculatorConfig"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_mlp_speculator"] = [ + "MLPSpeculator", + "MLPSpeculatorPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_mlp_speculator import MLPSpeculatorConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_mlp_speculator import ( + MLPSpeculator, + MLPSpeculatorPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/mlp_speculator/configuration_mlp_speculator.py b/src/transformers/models/mlp_speculator/configuration_mlp_speculator.py new file mode 100644 index 000000000000..2aa45ea61e05 --- /dev/null +++ b/src/transformers/models/mlp_speculator/configuration_mlp_speculator.py @@ -0,0 +1,54 @@ +from typing import List + +from ...configuration_utils import PretrainedConfig + + +class MLPSpeculatorConfig(PretrainedConfig): + model_type = "mlp_speculator" + + def __init__( + self, + vocab_size: int = 32000, + emb_dim: int = 
4096, + inner_dim: int = 0, + n_predict: int = 3, + top_k_tokens_per_head: List[int] = [5, 4, 3], + n_candidates: int = 5, + tie_weights: bool = False, + scale_input: bool = False, + **kwargs, + ): + """ + Initialize an MLPSpeculatorConfig + + Args: + vocab_size: int + the model vocab size + emb_dim: int + the model embedding dimension + inner_dim: int + the inner dimension of the model. If 0, will be the emb_dim. + n_predict: int + the number of lookaheads for the speculator + top_k_tokens_per_head: List[int] + Number of tokens to consider from each head when forming the candidate tree. + For each candidate branch in the tree, head n produces topk[n] additional sub-branches. + n_candidates: int + number of child candidates to create per sequence + tie_weights : bool + If true, use a single set of weights for every model head/stage after the first. + The initial projection from the base model may have a different size, so that stays separate. + scale_input: bool + If true, apply an extra layernorm to the initial state vector input. + Helps training dynamics, particularly when base model output has unusual scale. + """ + assert len(top_k_tokens_per_head) == n_predict + self.vocab_size = vocab_size + self.emb_dim = emb_dim + self.inner_dim = inner_dim + self.n_predict = n_predict + self.top_k_tokens_per_head = top_k_tokens_per_head + self.n_candidates = n_candidates + self.tie_weights = tie_weights + self.scale_input = scale_input + super().__init__(**kwargs) diff --git a/src/transformers/models/mlp_speculator/modeling_mlp_speculator.py b/src/transformers/models/mlp_speculator/modeling_mlp_speculator.py new file mode 100644 index 000000000000..67dadcff7f00 --- /dev/null +++ b/src/transformers/models/mlp_speculator/modeling_mlp_speculator.py @@ -0,0 +1,323 @@ +import math +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...modeling_utils import PreTrainedModel +from .configuration_mlp_speculator import MLPSpeculatorConfig + + +class MLPSpeculatorLayerNorm(nn.Module): + """ + A L2 normalization implementation + ... + Args + ---- + normalized_shape : int + Dimensionality of input data (size of final tensor axis) + eps : float + Safety term to prevent division by zero. Make sure the chosen value + fits in the range of your encoding scheme + (i.e. fp16 requires eps >= 6e-8). + elementwise_scale_and_shift : bool + Include a learned scaling and shift term after normalization. + """ + + def __init__( + self, + normalized_shape, + eps=1e-06, + elementwise_scale_and_shift=True, + ): + super().__init__() + self.elementwise_scale_and_shift = elementwise_scale_and_shift + if self.elementwise_scale_and_shift: + self.weight = nn.Parameter(torch.empty(normalized_shape)) + self.bias = nn.Parameter(torch.empty(normalized_shape)) + self.eps = eps + + def forward(self, x): + xf = x + xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps) + x = xf.type_as(x) + if self.elementwise_scale_and_shift: + x = self.weight * x + x = x + self.bias + return x + + +class MLPSpeculator(nn.Module): + """ + This is a simple MLP-based speculator that functions similarly to Medusa + (https://arxiv.org/abs/2401.10774), ingesting context via the final embedding + vector from the base model. However, this model also conditions on previously + predicted tokens, similarly to an RNN, allowing it to generate better-quality n-grams. 
+ + The architecture is as flat and simple as possible: for each prediction head, + the current state vector is projected into a new latent space and added to the + previous token's embedding. This sum goes through layernorm and activation, forming + the new state vector. This state predicts the next token (or set of candidate tokens) + for the current head, and then is passed on to the next. + ... + Args + ---- + emb_dim : int + Dimensionality of the input vector from the base model. + inner_dim : int + Latent dimensionality of the speculator model. + vocab_size : int + Number of entries in the tokenizer associated with the base model. + n_predict : int + Number of heads / number of tokens to guess ahead. Model size and speed scale with this value. + tie_weights : bool + If true, use a single set of weights for every model head/stage after the first. + The initial projection from the base model may have a different size, so that stays separate. + scale_input: bool + If true, apply an extra layernorm to the initial state vector input. + Helps training dynamics, particularly when base model output has unusual scale. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.n_predict = config.n_predict + self.emb_dim = config.emb_dim + self.inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim + self.vocab_size = config.vocab_size + self.tie_weights = config.tie_weights + self.scale_input = config.scale_input + + self.emb = nn.ModuleList([nn.Embedding(self.vocab_size, self.inner_dim) for _ in range(self.n_predict)]) + self.proj = nn.ModuleList( + [ + nn.Linear((self.emb_dim if i == 0 else self.inner_dim), self.inner_dim, bias=False) + for i in range(self.n_predict) + ] + ) + self.head = nn.ModuleList( + [nn.Linear(self.inner_dim, self.vocab_size, bias=False) for _ in range(self.n_predict)] + ) + self.ln = nn.ModuleList( + [MLPSpeculatorLayerNorm(self.inner_dim, elementwise_scale_and_shift=True) for _ in range(self.n_predict)] + ) + if self.scale_input: + self.ln0 = MLPSpeculatorLayerNorm(self.emb_dim, elementwise_scale_and_shift=False) + # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation + self.state_weight = 0.5 ** (0.5 / self.n_predict) + self.emb_weight = math.sqrt((1 - self.state_weight**2) * (self.inner_dim / 2)) + self.activation = nn.GELU() + + # Handle weight tying as specified + if self.tie_weights: + assert self.n_predict > 1, "You cannot tie weights between stages when only 1 exists" + for emb in self.emb: + emb.weight = self.emb[0].weight + + for head in self.head: + head.weight = self.head[0].weight + + for ln in self.ln: + ln.weight = self.ln[0].weight + ln.bias = self.ln[0].bias + + # Since first proj has different size, allow different initial proj from base into model + for i in range(2, self.n_predict): + self.proj[i].weight = self.proj[1].weight + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Embedding) or isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 1 / math.sqrt(self.inner_dim)) + elif isinstance(m, MLPSpeculatorLayerNorm) and hasattr(m, "weight"): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def generate_suffixes( + self, + state: torch.Tensor, + ind: torch.Tensor, + topk: List[int] = [5, 4, 3], + n: int = 5, + ) -> torch.Tensor: + """ + FOR INFERENCE + Generate tree of candidate sequences. + ... + Args + ---- + state : torch.Tensor + Most recent embedding vector from the base model (pre-classification head). 
+ Expects size [b 1 d] where b is batch size and d is model width. + ind : torch.Tensor + Token indices of the base model's most recent predicted token(s). + Expects size [b 1] where b is batch size. + topk : List(int) + Number of tokens to consider from each head when forming the candidate tree. + For each candidate branch in the tree, head n produces topk[n] additional sub-branches. + n : int + Given the final tree of prod(topk) candidates, return only the top n most confident. + ... + Output : torch.Tensor + The tensor of most likely candidate sequences. + Has size [b n self.n_predict], where b is batch size and n is provided above. + """ + # k indicates # of candidates + # h indicates # of generated tokens + b = state.size(0) + k = math.prod(topk) + out = torch.empty(b, 1, k, self.n_predict, device=state.device).int() # b 1 k h -> b k 1 h + log_probs = torch.zeros(b, 1, k, device=state.device) # b 1 k -> b k 1 + assert len(topk) == self.n_predict, ( + f"You must provide a topk number for each head ({self.n_predict} heads, {len(topk)} provided)" + ) + if self.scale_input: + state = self.ln0(state) / (2**0.5) + for i in range(self.n_predict): + # Project and predict + z = self.emb[i](ind) # b k d + state = self.proj[i](state) + # Weighted add of state_weight*state and emb_weight*z + # Let subsequent LN take care of denominator + # state_weight is close to 1, so shouldn't be any precision issues + state = torch.add(state, z, alpha=self.emb_weight / self.state_weight) + state = self.activation(self.ln[i](state)) # b k d + probs = F.log_softmax(self.head[i](state), dim=2) # b k v + probs, preds = probs.topk(topk[i], dim=2) # b k k' + + # Update candidate set with new predictions, repeating shared prefixes as needed + out = out.view(b, preds.size(1) * preds.size(2), -1, self.n_predict) + out[:, :, :, i] = preds.view(b, -1, 1) + + # Update state, log_probs and ind for new predictions + state = state.unsqueeze(2).expand(-1, -1, topk[i], -1) # b k k' d + state = state.reshape(b, -1, state.size(3)) # b kk' d + ind = preds.view(b, -1) # b kk' + log_probs = log_probs.view(b, probs.size(1) * probs.size(2), -1) + log_probs = log_probs.add(probs.view(b, -1, 1)) + + # Take only top n best guesses + out = out.view(b, k, self.n_predict) + log_probs = log_probs.view(b, k) + best_guesses = log_probs.topk(n, dim=1)[1] # b k + return out.gather(1, best_guesses.unsqueeze(2).expand(-1, -1, self.n_predict)) # b n h + + def forward( + self, + state: torch.Tensor, + inds: torch.Tensor, + ) -> torch.Tensor: + """ + FOR TRAINING + A parallel forward pass on pre-existing ground-truth tokens in pretraining contexts. + Produces self.n_predict predicted tokens for each token embedding in state. + Inds requires self.n_predict extra tokens on the right to "simulate" recursive + behavior for end positions. + ... + Args + ---- + state : torch.Tensor + Embedding vectors from the base model for a given sequence. + Expects size [b n d] where b is batch size, n is seq len, and d is model width. + inds : torch.Tensor + Ground-truth token indices. inds[:,i] is the prediction coming from state[:,i] + (or the legal fiction ground truth corresponding to that prediction). + Expects size [b n+self.n_predict]. + ... + Output : torch.Tensor + Prediction logits at each position, for each head of the speculator. + Has size [self.n_predict b n v] where v is vocab size. 
+ """ + out = [] + if self.scale_input: + state = self.ln0(state) / (2**0.5) + for i in range(self.n_predict): + z = self.emb[i](inds[:, i : i + state.size(1)]) # b n d + state = self.proj[i](state) + # Weighted add of state_weight*state and emb_weight*z + # Let subsequent LN take care of denominator + # state_weight is close to 1, so shouldn't be any precision issues + state = torch.add(state, z, alpha=self.emb_weight / self.state_weight) + state = self.activation(self.ln[i](state)) # b n d + out.append(self.head[i](state)) # b n v + return torch.stack(out, dim=0) # h b n v + + +class MLPSpeculatorPreTrainedModel(PreTrainedModel): + """ + Hugging Face `PreTrainedModel` wrapper around `MLPSpeculator`, providing standard loading and saving support. + """ + + config_class = MLPSpeculatorConfig + _no_split_modules = ["MLPSpeculator"] + + def __init__(self, config: MLPSpeculatorConfig, speculator: Optional[MLPSpeculator] = None): + super().__init__(config) + if speculator is None: + self.speculator = MLPSpeculator(config) + self.speculator.reset_parameters() + else: + self.speculator = speculator + + def generate_suffixes( + self, + state: torch.Tensor, + ind: torch.Tensor, + topk: List[int] = [5, 4, 3], + n_candidates: int = 5, + ) -> torch.Tensor: + """ + FOR INFERENCE + Generate tree of candidate sequences. + ... + Args + ---- + state : torch.Tensor + Most recent embedding vector from the base model (pre-classification head). + Expects size [b 1 d] where b is batch size and d is model width. + ind : torch.Tensor + Token indices of the base model's most recent predicted token(s). + Expects size [b 1] where b is batch size. + topk : List(int) + Number of tokens to consider from each head when forming the candidate tree. + For each candidate branch in the tree, head n produces topk[n] additional sub-branches. + n_candidates : int + Given the final tree of prod(topk) candidates, return only the top n most confident. + ... + Output : torch.Tensor + The tensor of most likely candidate sequences. + Has size [b n self.n_predict], where b is batch size and n is provided above. + """ + return self.speculator.generate_suffixes(state, ind, topk, n_candidates) + + def forward( + self, + state: torch.Tensor, + inds: torch.Tensor, + ) -> torch.Tensor: + """ + FOR TRAINING + A parallel forward pass on pre-existing ground-truth tokens in pretraining contexts. + Produces self.n_predict predicted tokens for each token embedding in state. + Inds requires self.n_predict extra tokens on the right to "simulate" recursive + behavior for end positions. + ... + Args + ---- + state : torch.Tensor + Embedding vectors from the base model for a given sequence. + Expects size [b n d] where b is batch size, n is seq len, and d is model width. + inds : torch.Tensor + Ground-truth token indices. inds[:,i] is the prediction coming from state[:,i] + (or the legal fiction ground truth corresponding to that prediction). + Expects size [b n+self.n_predict]. + ... + Output : torch.Tensor + Prediction logits at each position, for each head of the speculator. + Has size [self.n_predict b n v] where v is vocab size.
+ """ + return self.speculator(state, inds) + + def reset_parameters(self): + self.speculator.reset_parameters() diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 79b2fd4e2327..fcf1cb41a1c7 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6526,6 +6526,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class MLPSpeculator(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MLPSpeculatorPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MobileBertForMaskedLM(metaclass=DummyObject): _backends = ["torch"]
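For reference, here is a minimal usage sketch of how the pieces added in this patch fit together; it is not part of the diff. The checkpoint names are placeholders, and the snippet assumes a causal LM whose hidden size matches the speculator's `emb_dim`. The one hard requirement introduced above is `output_hidden_states=True`, since `MLPSpeculatorCandidateGenerator` conditions on the base model's last hidden state and raises a `ValueError` otherwise.

```python
# Hedged usage sketch. "base-model-ckpt" and "mlp-speculator-ckpt" are hypothetical
# checkpoint names; any causal LM plus a speculator trained for it (matching emb_dim
# and vocab_size) would be wired the same way.
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, MLPSpeculatorPreTrainedModel

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("base-model-ckpt")
model = AutoModelForCausalLM.from_pretrained("base-model-ckpt").to(device)
speculator = MLPSpeculatorPreTrainedModel.from_pretrained("mlp-speculator-ckpt").to(device)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(device)

# `assistant_model` routes generation through assisted decoding; when the speculator's
# config lists "MLPSpeculatorPreTrainedModel" in `architectures`, `_get_candidate_generator`
# selects MLPSpeculatorCandidateGenerator. `output_hidden_states=True` is mandatory so the
# candidate generator can cache the last hidden state between iterations.
outputs = model.generate(
    **inputs,
    assistant_model=speculator,
    output_hidden_states=True,
    return_dict_in_generate=True,
    max_new_tokens=64,
)
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
```

Design-wise, this mirrors Medusa-style drafting, but each speculator head also conditions on the previously drafted token, which is why the candidate generator tracks both `last_hidden_state` and `last_token_id` between calls.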