Commit 6ac164a

jinhongyii authored and junrushao committed
Introduce Mixtral MoE Model
This PR introduces support for Mixtral MoE models with MLC's latest SLM quantization/compilation pipeline. It includes the following pieces of changes:

**Operators.** We implemented a set of operators in TIR's TVMScript format in two files, `moe_misc` and `moe_matmul`. These TIR kernels implement the "transpose indices" and "blocked-CSR-COO" techniques described in MegaBlocks [1]. `moe_misc.py` primarily concerns sparsity-related operators, including:

- `get_indices`, `get_indptr` and `scatter_output`: CSR-style index manipulation and array shuffling that make the input range each expert handles contiguous.
- `moe_sum`, `moe_cumsum` and `topk`: standard operators specialized for MoE use cases, e.g. where the number of experts and the number of activated experts are small.

`moe_matmul.py` includes the non-quantized and quantized GEMV and group GEMM operators used in MoE model serving. In single-batch decoding, GEMV operators typically suffice, but group GEMM is a necessary dependency in both prefilling and batched decoding. A sketch of how these routing operators fit together is given after this message.

**Model architecture.** We reuse the attention building block from Mistral and implement the MoE MLP in `mixtral_model.py`. In Mixtral, there are three groups of experts in each MLP, where `e1` and `e3` are the gate/up projections (project-in) and `e2` is the down projection (project-out).

**Weight quantization.** We batch all experts of the same kind into a single tensor of shape `(Ne, N, K)`, where `Ne` is the total number of experts, `N` is the number of output features, and `K` is the number of input features. Applying group quantization, we compress along the `K` dimension, consistent with the rest of the project.

**Performance.** The current TIR is highly optimized for non-tensor-core scenarios; tensor core performance is left for a PR in the near future.

**Try out MLC's Mixtral Model.** TBD

[1] Gale, Trevor, Deepak Narayanan, Cliff Young, and Matei Zaharia. "MegaBlocks: Efficient Sparse Training with Mixture-of-Experts." Proceedings of MLSys 2023.
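To make the routing described above concrete, here is a minimal NumPy sketch of the multi-token path (top-k gating, cumulative sum, index/indptr construction, gather, per-expert GEMM, scatter, weighted sum). It is illustrative only: the steps mirror the TIR kernels in `moe_misc.py` and `moe_matmul.py` conceptually, but this is not the actual TIR implementation, and the sizes are made up.

```python
import numpy as np

# Toy sizes for illustration only (not Mixtral's real dimensions).
num_tokens, num_experts, experts_per_tok, hidden = 4, 8, 2, 16
rng = np.random.default_rng(0)

x = rng.standard_normal((num_tokens, hidden)).astype("float32")
gate_logits = rng.standard_normal((num_tokens, num_experts)).astype("float32")

# topk: pick the highest-scoring experts per token, then softmax their scores.
expert_indices = np.argsort(-gate_logits, axis=-1)[:, :experts_per_tok]
expert_scores = np.take_along_axis(gate_logits, expert_indices, axis=-1)
expert_weights = np.exp(expert_scores) / np.exp(expert_scores).sum(-1, keepdims=True)

# Histogram + cumulative sum of expert assignments: indptr[e] marks where
# expert e's contiguous slice of (token, expert) pairs begins.
flat_experts = expert_indices.reshape(-1)                 # [num_tokens * experts_per_tok]
counts = np.bincount(flat_experts, minlength=num_experts)
indptr = np.concatenate([[0], np.cumsum(counts)])         # [num_experts + 1]

# A permutation that groups (token, expert) pairs by expert id, so every
# expert reads a contiguous range of rows.
indices = np.argsort(flat_experts, kind="stable")         # [num_tokens * experts_per_tok]
gathered = x[indices // experts_per_tok]                  # row i belongs to token indices[i] // experts_per_tok

# Group GEMM: every expert applies its own (N, K) weight to its contiguous slice.
w = rng.standard_normal((num_experts, hidden, hidden)).astype("float32")  # (Ne, N, K)
out = np.empty_like(gathered)
for e in range(num_experts):
    lo, hi = indptr[e], indptr[e + 1]
    out[lo:hi] = gathered[lo:hi] @ w[e].T

# Undo the permutation, then combine the per-expert results with the
# softmaxed gating weights.
scattered = np.empty_like(out)
scattered[indices] = out
y = (scattered.reshape(num_tokens, experts_per_tok, hidden) * expert_weights[..., None]).sum(axis=1)
print(y.shape)  # (4, 16)
```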
1 parent 5e23900 commit 6ac164a

File tree

19 files changed: +1522 −186 lines


python/mlc_chat/interface/convert_weight.py

Lines changed: 4 additions & 1 deletion
@@ -51,7 +51,10 @@ def _device_to_str(device: Device) -> str:
 
 
 def _calc_total_params(model: nn.Module) -> int:
-    _, named_params, _ = model.export_tvm(spec=model.get_default_spec(), allow_extern=True)
+    _, named_params, _ = model.export_tvm(  # type: ignore[misc]
+        spec=model.get_default_spec(),  # type: ignore[attr-defined]
+        allow_extern=True,
+    )
     total_params = 0
     for _, param in named_params:
         total_params += math.prod(param.shape)

python/mlc_chat/model/gpt2/gpt2_model.py

Lines changed: 7 additions & 21 deletions
@@ -3,13 +3,13 @@
 TODO: add docstring
 """
 import dataclasses
-import math
 from typing import Any, Dict, Optional
 
 from tvm import te, tir
 from tvm.relax.frontend import nn
 from tvm.relax.frontend.nn import Tensor, op
 
+from mlc_chat import op as op_ext
 from mlc_chat.support import logging
 from mlc_chat.support.config import ConfigBase
 from mlc_chat.support.style import bold
@@ -110,29 +110,15 @@ def forward(
 
         self.k_cache.append(op.squeeze(k, axis=0))
         self.v_cache.append(op.squeeze(v, axis=0))
-        k = op.reshape(self.k_cache.view(t), (b, t, h, d))
-        v = op.reshape(self.v_cache.view(t), (b, t, h, d))
-
-        q = q.permute_dims([0, 2, 1, 3])  # [b, h, s, d]
-        k = k.permute_dims([0, 2, 1, 3])  # [b, h, t, d]
-        v = v.permute_dims([0, 2, 1, 3])  # [b, h, t, d]
-
-        attn_weights = op.matmul(
-            q, k.permute_dims([0, 1, 3, 2])  # [b, h, s, d] x [b, h, d, t] = [b, h, s, t]
-        ) / math.sqrt(d)
+        k = self.k_cache.view(t)
+        v = self.v_cache.view(t)
 
         if self.scale_attn_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
-
-        dtype = attn_weights.dtype
-        attn_weights = attn_weights.maximum(tir.min_value(dtype)).minimum(attention_mask)
-        if dtype == "float32":
-            attn_weights = op.softmax(attn_weights, axis=-1)
+            attn_score_scaling_factor = 1.0 / float(self.layer_idx + 1)
         else:
-            attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(dtype)
-        # [b, h, s, t] x [b, h, t, d] => [b, h, s, d] => [b, s, h, d]
-        output = op.matmul(attn_weights, v)
-        return self.c_proj(output.permute_dims([0, 2, 1, 3]).reshape((b, s, h * d)))
+            attn_score_scaling_factor = 1.0
+        output = op_ext.attention(q, k, v, attention_mask, attn_score_scaling_factor)
+        return self.c_proj(output)
 
 
 class GPT2MLP(nn.Module):

python/mlc_chat/model/llama/llama_model.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_chat import op as op_ext
-from mlc_chat.nn.kv_cache import FlashInferPagedKVCache, PagedKVCache
+from mlc_chat.nn import FlashInferPagedKVCache, PagedKVCache
 from mlc_chat.support import logging
 from mlc_chat.support import tensor_parallel as tp
 from mlc_chat.support.config import ConfigBase
@@ -342,7 +342,7 @@ def create_flashinfer_paged_kv_cache(
         num_kv_heads = self.num_key_value_heads // self.tensor_parallel_shards
         # Note: Right now we only have FlashInfer-based KV cache supported.
         # TIR version will be introduced soon.
-        return FlashInferPagedKVCache.create(
+        return FlashInferPagedKVCache(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             page_size=page_size,

python/mlc_chat/model/mistral/mistral_model.py

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@ def __init__(self, config: MistralConfig):
             [MistralDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)]
         )
         self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)
-        self.tensor_parallel_shards = config.tensor_parallel_shards > 1
+        self.tensor_parallel_shards = config.tensor_parallel_shards
 
     def forward(  # pylint: disable=too-many-arguments
         self,

python/mlc_chat/model/mixtral/__init__.py

Whitespace-only changes.

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
"""
This file specifies how MLC's Mixtral parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""
import functools

import numpy as np

from mlc_chat.loader import ExternMapping
from mlc_chat.quantization import Quantization

from .mixtral_model import MixtralConfig, MixtralForCasualLM


def huggingface(model_config: MixtralConfig, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of HuggingFace PyTorch parameters.

    Parameters
    ----------
    model_config : MixtralConfig
        The configuration of the Mixtral model.

    quantization : Quantization
        The quantization configuration.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """
    model = MixtralForCasualLM(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    for i in range(model_config.num_hidden_layers):
        # Add QKV in self attention
        attn = f"model.layers.{i}.self_attn"
        mlc_name = f"{attn}.qkv_proj.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{attn}.q_proj.weight",
                f"{attn}.k_proj.weight",
                f"{attn}.v_proj.weight",
            ],
            functools.partial(
                lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )

        # Add gates in MLP (when MoE is enabled)
        mlp = f"model.layers.{i}.block_sparse_moe"
        mlc_mlp = f"model.layers.{i}.moe"
        mlc_name = f"{mlc_mlp}.e1_e3.weight"
        mlc_param = named_parameters[mlc_name]

        def combine_expert_gate_up(*hf_params, dtype):
            stack = []
            for i in range(0, len(hf_params), 2):
                stack.append(np.concatenate([hf_params[i], hf_params[i + 1]], axis=0))
            return np.stack(stack, axis=0).astype(dtype)

        mapping.add_mapping(
            mlc_name,
            functools.reduce(
                lambda a, b: a + b,
                [
                    [
                        f"{mlp}.experts.{expert_id}.w1.weight",
                        f"{mlp}.experts.{expert_id}.w3.weight",
                    ]
                    for expert_id in range(model_config.num_local_experts)
                ],
            ),
            functools.partial(
                combine_expert_gate_up,
                dtype=mlc_param.dtype,
            ),
        )

        mlc_name = f"{mlc_mlp}.e2.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{mlp}.experts.{expert_id}.w2.weight"
                for expert_id in range(model_config.num_local_experts)
            ],
            functools.partial(
                lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )

        mlc_name = f"{mlc_mlp}.gate.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [f"{mlp}.gate.weight"],
            functools.partial(
                lambda x, dtype: x.astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )

        # inv_freq is not used in the model
        mapping.add_unused(f"{attn}.rotary_emb.inv_freq")

    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            mapping.add_mapping(
                mlc_name,
                [mlc_name],
                functools.partial(
                    lambda x, dtype: x.astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
    return mapping
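For reference, the `combine_expert_gate_up` helper above determines the layout of the batched `e1_e3` tensor: each expert's gate (`w1`) and up (`w3`) projections are concatenated along the output dimension, and the experts are then stacked into the `(Ne, N, K)` layout described in the commit message. A toy NumPy sketch of the resulting shapes, using made-up sizes and assuming the usual HuggingFace `nn.Linear` `(out_features, in_features)` weight layout:

```python
import numpy as np

# Made-up toy sizes, not the real Mixtral dimensions.
num_experts, hidden_size, intermediate_size = 4, 8, 6

# Per-expert HuggingFace weights: w1 (gate) and w3 (up) are (intermediate, hidden);
# w2 (down) is (hidden, intermediate).
w1 = [np.ones((intermediate_size, hidden_size), "float32") for _ in range(num_experts)]
w3 = [np.ones((intermediate_size, hidden_size), "float32") for _ in range(num_experts)]
w2 = [np.ones((hidden_size, intermediate_size), "float32") for _ in range(num_experts)]

# e1_e3: concatenate gate/up per expert along the output dimension, then stack experts.
e1_e3 = np.stack([np.concatenate([g, u], axis=0) for g, u in zip(w1, w3)], axis=0)
print(e1_e3.shape)  # (4, 12, 8) -> (Ne, 2 * intermediate, hidden)

# e2: stack the down projections directly.
e2 = np.stack(w2, axis=0)
print(e2.shape)     # (4, 8, 6) -> (Ne, hidden, intermediate)
```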
Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
"""Implementation for Mixtral architecture."""
import dataclasses

from tvm import tir
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op

from mlc_chat import op as op_ext
from mlc_chat.model.mistral.mistral_model import (
    MistralAttention,
    MistralConfig,
    MistralForCasualLM,
    MistralModel,
    RotaryEmbedding,
)
from mlc_chat.nn.expert import MixtralExperts
from mlc_chat.support import logging
from mlc_chat.support import tensor_parallel as tp

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class MixtralConfig(MistralConfig):  # pylint: disable=too-many-instance-attributes
    """Configuration of the Mixtral model."""

    num_local_experts: int = 0
    num_experts_per_tok: int = 0


# pylint: disable=invalid-name,missing-docstring,too-many-locals,fixme


class MixtralMoE(nn.Module):
    """Mixture of experts"""

    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.num_experts_per_tok = config.num_experts_per_tok
        self.num_local_experts = config.num_local_experts
        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards
        self.gate = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.num_local_experts,
            bias=False,
        )
        self.e1_e3 = MixtralExperts(
            self.num_local_experts,
            in_features=config.hidden_size,
            out_features=2 * self.intermediate_size,
        )
        self.e2 = MixtralExperts(
            self.num_local_experts,
            in_features=self.intermediate_size,
            out_features=config.hidden_size,
        )
        self.dtype = "float32"

    def forward(self, x: Tensor):
        def _expert_forward(x: Tensor, indptr: Tensor):
            x1_x3 = self.e1_e3(x, indptr)
            x1, x3 = op.split(x1_x3, indices_or_sections=2, axis=-1)
            x = self.e2(op.silu(x1) * x3, indptr)
            return x

        experts_per_tok = self.num_experts_per_tok  # activated experts per token
        local_experts = self.num_local_experts  # total number of experts
        batch_size, seq_len, hidden_size = x.shape
        num_tokens = batch_size * seq_len
        x = x.reshape(num_tokens, hidden_size)
        # gate: [num_tokens, local_experts]
        gate: Tensor = self.gate(x)
        # expert_weights: [num_tokens, experts_per_tok]
        # expert_indices: [num_tokens, experts_per_tok]
        expert_weights, expert_indices = op_ext.moe_misc.topk(gate, experts_per_tok)
        expert_weights = op.softmax(expert_weights.astype("float32"), axis=-1).astype(self.dtype)
        if num_tokens == 1:
            # x: [num_tokens * experts_per_tok, hidden_size]
            x = _expert_forward(x, expert_indices)
        else:
            # cumsum: [num_tokens * total_experts]
            cumsum = op_ext.moe_misc.moe_cumsum(expert_indices, local_experts)
            # indices: [num_tokens * experts_per_tok]
            indices = op_ext.moe_misc.get_indices(cumsum, expert_indices)
            # indptr: [num_local_experts + 1]
            indptr = op_ext.moe_misc.get_indptr(cumsum, local_experts)
            # x: [num_tokens * experts_per_tok, hidden_size]
            x = op.take(x, indices / experts_per_tok, axis=0)
            x = _expert_forward(x, indptr)
            x = op_ext.moe_misc.scatter_output(x, indices)
        # x: [num_tokens, experts_per_tok, hidden_size]
        x = x.reshape(  # pylint: disable=too-many-function-args
            num_tokens, experts_per_tok, hidden_size
        ) * expert_weights.reshape(  # pylint: disable=too-many-function-args
            num_tokens, experts_per_tok, 1
        )
        # x: [num_tokens, hidden_size]
        x = op_ext.moe_misc.moe_sum(x, dim=1)
        x = x.reshape(batch_size, seq_len, hidden_size)  # pylint: disable=too-many-function-args
        return x


class MixtralDecoderLayer(nn.Module):
    """Mixtral decoder layer"""

    def __init__(self, config: MixtralConfig, rotary_embedding: RotaryEmbedding):
        eps = config.rms_norm_eps
        self.self_attn = MistralAttention(config, rotary_embedding)
        self.moe = MixtralMoE(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, eps, bias=False)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, eps, bias=False)

        def _set_tp():
            def _set(layer, hint):
                layer.weight.attrs["shard_strategy"] = hint

            hd = config.head_dim
            q = self.self_attn.num_q_heads * hd
            k = self.self_attn.num_kv_heads * hd
            v = self.self_attn.num_kv_heads * hd
            i = self.moe.intermediate_size
            _set(self.self_attn.qkv_proj, tp.ShardSingleDim("_shard_qkv", segs=[q, k, v], dim=0))
            _set(self.self_attn.o_proj, tp.ShardSingleDim("_shard_o", dim=1))
            _set(self.moe.e1_e3, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=1))
            _set(self.moe.e2, tp.ShardSingleDim("_shard_mlp_down", dim=2))

        self.tensor_parallel_shards = config.tensor_parallel_shards
        _set_tp()

    def forward(  # pylint: disable=too-many-arguments
        self,
        hidden_states: Tensor,
        attention_mask: Tensor,
        rolling_cache_len: tir.Var,
        kv_seq_len: tir.Var,
        cache_offset: tir.Var,
    ):
        """Forward pass of a decoder layer; calculate attention, and add a residual connection."""

        def _apply_residual(out, residual):
            if self.tensor_parallel_shards > 1:
                return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum")
            return out + residual

        out = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask,
            rolling_cache_len,
            kv_seq_len,
            cache_offset,
        )
        hidden_states = _apply_residual(out, residual=hidden_states)
        out = self.moe(self.post_attention_layernorm(hidden_states))
        hidden_states = _apply_residual(out, residual=hidden_states)
        return hidden_states


class MixtralModel(MistralModel):
    """Exact same as LlamaModel."""

    def __init__(self, config: MixtralConfig):
        super().__init__(config)
        rotary_embedding = RotaryEmbedding(config)
        self.layers = nn.ModuleList(
            [MixtralDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)]
        )


class MixtralForCasualLM(MistralForCasualLM):
    """Same as LlamaForCausalLM, except for the use of sliding window attention."""

    def __init__(self, config: MixtralConfig):
        super().__init__(config)
        self.model = MixtralModel(config)
