2 changes: 1 addition & 1 deletion server/text_generation_server/layers/__init__.py
@@ -3,11 +3,11 @@
     TensorParallelRowLinear,
     TensorParallelEmbedding,
 )
-from text_generation_server.layers.speculative import SpeculativeHead
 from text_generation_server.layers.linear import (
     get_linear,
     FastLinear,
 )
+from text_generation_server.layers.speculative import SpeculativeHead
 
 # Just to add the `load` methods.
 from text_generation_server.layers.layernorm import load_layer_norm
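The one-line move above is load-bearing: `speculative.py` now pulls in `MLPSpeculatorHead` from `mlp.py`, and `mlp.py` imports `TensorParallelEmbedding` and `FastLinear` back out of this `__init__.py` (see the new file below). Importing `SpeculativeHead` only after the `linear` import plausibly keeps `FastLinear` bound before the circular chain re-enters the package. A minimal sketch of the failure mode this ordering avoids, with illustrative module names rather than the real ones:

```python
# pkg/__init__.py -- the broken ordering (sketch)
from pkg.speculative import SpeculativeHead  # begins importing pkg.speculative...
from pkg.linear import FastLinear            # ...before FastLinear is bound here

# pkg/speculative.py (sketch)
from pkg.mlp import MLPSpeculatorHead

# pkg/mlp.py (sketch)
from pkg import FastLinear  # ImportError: cannot import name 'FastLinear'
                            # from partially initialized module 'pkg'
```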
35 changes: 19 additions & 16 deletions server/text_generation_server/layers/medusa.py
@@ -69,21 +69,24 @@ def load(config, prefix: str, weights):
         from safetensors import safe_open
         import json
 
-        use_medusa = config.use_medusa
+        speculator = config.speculator
 
-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        path = speculator["path"]
+        medusa_config = str(Path(path) / "config.json")
 
-        with open(medusa_config, "r") as f:
-            medusa_config = json.load(f)
-        routing = weights.routing
-        with safe_open(filename, framework="pytorch") as f:
-            for k in f.keys():
-                if k in routing and routing[k] != filename:
-                    raise RuntimeError(
-                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
-                    )
-                routing[k] = filename
+        for fname in speculator["model_paths"]:
+            filename = str(Path(path) / fname)
+
+            with open(medusa_config, "r") as f:
+                medusa_config = json.load(f)
+            routing = weights.routing
+            with safe_open(filename, framework="pytorch") as f:
+                for k in f.keys():
+                    if k in routing and routing[k] != filename:
+                        raise RuntimeError(
+                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                        )
+                    routing[k] = filename
 
         medusa = MedusaModel(config, medusa_config, weights)
         lm_head = TensorParallelHead.load(config, prefix, weights)
@@ -108,10 +111,10 @@ def __init__(self, config, prefix, weights):
         from safetensors import safe_open
         import json
 
-        use_medusa = config.use_medusa
+        speculator = config.speculator
 
-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        medusa_config = str(Path(speculator) / "config.json")
+        filename = str(Path(speculator) / "medusa_lm_head.safetensors")
 
         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
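Both hunks above switch from `config.use_medusa` to `config.speculator`. In the V1 `load` path the value is treated as a dict carrying `path` and `model_paths`, so a Medusa head may now span several safetensors files; the V2 `__init__` path still builds both filenames directly from `Path(speculator)`. A hypothetical value with the shape the V1 loop expects (the keys come from the diff, the values are invented):

```python
# Hypothetical config.speculator value for the MedusaHeadV1.load path above.
speculator = {
    "path": "/data/medusa-heads",                   # directory holding config.json
    "model_paths": ["medusa_lm_head.safetensors"],  # one or more weight files
}
```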
176 changes: 176 additions & 0 deletions server/text_generation_server/layers/mlp.py
@@ -0,0 +1,176 @@
+import torch
+import math
+from torch import nn
+from torch.nn import functional as F
+from typing import Optional, Tuple
+from text_generation_server.layers import TensorParallelEmbedding, FastLinear
+from text_generation_server.layers.tensor_parallel import TensorParallelHead
+from text_generation_server.utils.speculate import get_speculate
+
+
+class MLPSpeculatorLayerNorm(nn.Module):
+    """
+    An L2 normalization implementation
+    ...
+    Args
+    ----
+    normalized_shape : int
+        Dimensionality of input data (size of final tensor axis)
+    elementwise_scale_weight : torch.Tensor
+        learned scaling term applied after normalization
+    elementwise_shift_bias : torch.Tensor
+        learned bias term applied after normalization
+    eps : float
+        Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (e.g. fp16 requires eps >= 6e-8).
+    """
+
+    def __init__(
+        self,
+        prefix,
+        config,
+        weights,
+        eps=1e-06,
+    ):
+        super(MLPSpeculatorLayerNorm, self).__init__()
+        self.weight = weights.get_tensor(f"{prefix}.weight")
+        self.bias = weights.get_tensor(f"{prefix}.bias")
+        self.eps = eps
+
+    def forward(self, x):
+        xf = x
+        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
+        x = xf.type_as(x)
+        x = self.weight * x
+        x = x + self.bias
+        return x
+
+
+class MLPSpeculatorModel(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.config = config
+        self.n_predict = get_speculate()
+        self.hidden_size = config.hidden_size
+        self.emb = nn.ModuleList(
+            [
+                TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
+                for i in range(self.n_predict)
+            ]
+        )
+        self.proj = [
+            FastLinear.load(
+                config,
+                prefix=f"{prefix}.proj.{i}",
+                weights=weights,
+                bias=False,
+            )
+            for i in range(self.n_predict)
+        ]
+        self.head = nn.ModuleList(
+            [
+                FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False)
+                for i in range(self.n_predict)
+            ]
+        )
+        self.ln = nn.ModuleList(
+            [
+                MLPSpeculatorLayerNorm(
+                    prefix=f"{prefix}.ln.{i}",
+                    config=config,
+                    weights=weights,
+                )
+                for i in range(self.n_predict)
+            ]
+        )
+
+        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
+        self.state_weight = 0.5 ** (0.5 / self.n_predict)
+        self.emb_weight = math.sqrt(1 - self.state_weight**2)
+        self.activation = nn.GELU()
+        # TODO
+        self.vsize = config.vocab_size
+        self.inner_dim = config.speculator_config["inner_dim"]
+        self.top_k_tokens_per_head = [1] * self.n_predict
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_ids: torch.Tensor,
+    ):
+        top_k_tokens_per_head = self.top_k_tokens_per_head
+
+        # k indicates # of candidates
+        # h indicates # of generated tokens
+        state = hidden_states
+        b = state.size(0)
+        ind = input_ids.unsqueeze(0)
+        all_probs = torch.empty(
+            b, self.n_predict, self.vsize, device=state.device
+        )  # b k h v
+        assert (
+            len(top_k_tokens_per_head) == self.n_predict
+        ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
+        for i in range(self.n_predict):
+            # Project and predict
+            z = self.emb[i](ind)
+            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
+            state = self.proj[i](state) * self.state_weight + z
+            state = self.activation(self.ln[i](state))  # b k d
+            probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
+            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
+
+            # Update candidate set with new predictions
+
+            # Update distribution set with new logits
+            all_probs[:, i] = probs.exp()
+
+            # Update state, log_probs and ind for new predictions
+            state = state.unsqueeze(2).expand(
+                -1, -1, top_k_tokens_per_head[i], -1
+            )  # b k k' d
+            state = state.reshape(-1, b, state.size(3))  # b kk' d
+            ind = preds.view(-1, b)  # b kk'
+
+        speculative_logits = all_probs
+        return speculative_logits
+
+
+class MLPSpeculatorHead(nn.Module):
+    def __init__(self, lm_head, mlp_speculator):
+        super().__init__()
+        self.lm_head = lm_head
+        self.mlp_speculator = mlp_speculator
+
+    def forward(
+        self, input: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        logits = self.lm_head(input)
+        # If we have too many tokens, we skip speculative logits
+        if input.shape[0] > 128:
+            return logits, None
+
+        input_ids = logits.argmax(dim=-1)
+        speculative_logits = self.mlp_speculator(input, input_ids)
+        return logits, speculative_logits
+
+    @staticmethod
+    def load(config, prefix: str, weights):
+        from pathlib import Path
+        from safetensors import safe_open
+
+        speculator_path = config.speculator["path"]
+
+        for fname in config.speculator["model_paths"]:
+            filename = str(Path(speculator_path) / fname)
+            routing = weights.routing
+            with safe_open(filename, framework="pytorch") as f:
+                for k in f.keys():
+                    if k in routing and routing[k] != filename:
+                        raise RuntimeError(
+                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                        )
+                    routing[k] = filename
+
+        mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
+        lm_head = TensorParallelHead.load(config, prefix, weights)
+        return MLPSpeculatorHead(lm_head, mlp_speculator)
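`MLPSpeculatorLayerNorm` above is an RMS-style norm: divide by the root mean square over the last axis, then apply a learned scale and shift. A self-contained sketch of the same arithmetic with hand-built tensors in place of checkpoint weights (the function name is ours, not from the PR):

```python
import torch

def l2_layer_norm(x, weight, bias, eps=1e-6):
    # Normalize by the RMS of the last axis (as in MLPSpeculatorLayerNorm.forward),
    # then apply the learned elementwise scale and shift.
    xf = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return weight * xf.type_as(x) + bias

x = torch.randn(2, 3, 8)
print(l2_layer_norm(x, torch.ones(8), torch.zeros(8)).shape)  # torch.Size([2, 3, 8])
```

Note also the mixing constants in `MLPSpeculatorModel.__init__`: since `emb_weight ** 2 + state_weight ** 2 == 1`, each step of `forward` is a variance-preserving blend of the previous state and the fresh token embedding, consistent with the comment about the initial state contributing 50% of the state magnitude at the final head.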
43 changes: 30 additions & 13 deletions server/text_generation_server/layers/speculative.py
@@ -1,34 +1,51 @@
 import torch
+import json
 from typing import Tuple, Optional
-from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
 from text_generation_server.layers.tensor_parallel import TensorParallelHead
+from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
+from text_generation_server.layers.mlp import MLPSpeculatorHead
 
 
 class SpeculativeHead(torch.nn.Module):
-    def __init__(self, lm_head, medusa):
+    def __init__(self, lm_head, speculator):
         super().__init__()
         self.head = lm_head
-        self.medusa = medusa
+        self.speculator = speculator
 
     @staticmethod
     def load(config, prefix: str, weights):
-        use_medusa = config.use_medusa
-        if use_medusa:
-            lm_head = None
+        speculator = config.speculator
+        if speculator:
+            speculator_path = config.speculator["path"]
+            speculator_config = str(speculator_path / "config.json")
+
+            with open(speculator_config, "r") as f:
+                speculator_config = json.load(f)
+
+            config.speculator_config = speculator_config
             try:
-                medusa = MedusaHeadV1.load(config, prefix, weights)
-            except:
-                medusa = MedusaHeadV2(config, prefix, weights)
+                architecture = speculator_config["architectures"][0]
+
+                if architecture == "MLPSpeculatorPreTrainedModel":
+                    speculator = MLPSpeculatorHead.load(config, prefix, weights)
+                else:
+                    speculator = None
+            except KeyError:
+                try:
+                    speculator = MedusaHeadV1.load(config, prefix, weights)
+                except:
+                    speculator = MedusaHeadV2(config, prefix, weights)
+            lm_head = None
         else:
             lm_head = TensorParallelHead.load(config, prefix, weights)
-            medusa = None
-        return SpeculativeHead(lm_head, medusa)
+            speculator = None
+        return SpeculativeHead(lm_head, speculator)
 
     def forward(
         self, input: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if self.medusa is not None:
-            return self.medusa(input)
+        if self.speculator is not None:
+            return self.speculator(input)
 
         assert self.head is not None
         logits = self.head(input)
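For reference, the new `load` routes on the `architectures` field of the speculator's `config.json`: an MLP speculator advertises `MLPSpeculatorPreTrainedModel`, while a config with no `architectures` key raises `KeyError` and falls through to the `MedusaHeadV1`/`MedusaHeadV2` path. A minimal sketch of that routing against a hypothetical config (only `architectures`, and later `inner_dim`, are read by this PR; everything else is invented):

```python
import json

# Hypothetical speculator config.json contents.
speculator_config = json.loads(
    '{"architectures": ["MLPSpeculatorPreTrainedModel"], "inner_dim": 4096}'
)

try:
    architecture = speculator_config["architectures"][0]
    # An explicit architectures entry routes to the MLP speculator
    # (or to no speculator at all, if it names something else).
    uses_mlp = architecture == "MLPSpeculatorPreTrainedModel"
except KeyError:
    # Medusa configs carry no "architectures" key.
    uses_mlp = False

print("MLPSpeculatorHead" if uses_mlp else "Medusa fallback / none")
```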