
Commit 833d357

Back to non-fused MOE.
To back off vllm-project#2542 manually.
1 parent 4a64124 commit 833d357

File tree

1 file changed: +95 -179 lines changed


vllm/model_executor/models/mixtral.py

Lines changed: 95 additions & 179 deletions
@@ -23,6 +23,8 @@
 """Inference-only Mixtral model."""
 from typing import Iterable, List, Optional, Tuple
 
+import numpy as np
+
 import torch
 from torch import nn
 from transformers import MixtralConfig
@@ -37,6 +39,7 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                ReplicatedLinear,
+                                               LinearMethodBase,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
@@ -53,195 +56,112 @@
 from vllm.utils import print_warning_once
 
 
-class MixtralMoE(nn.Module):
-    """A tensor-parallel MoE implementation for Mixtral that shards each expert
-    across all ranks.
-
-    Each expert's weights are sharded across all ranks and a fused MoE
-    kernel is used for the forward pass, and finally we reduce the outputs
-    across ranks.
-    """
+class MixtralMLP(nn.Module):
 
     def __init__(
         self,
         num_experts: int,
-        top_k: int,
         hidden_size: int,
         intermediate_size: int,
-        params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+    ) -> None:
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
-        self.quant_config = quant_config
+        self.num_experts = num_experts
+        self.ffn_dim = intermediate_size
+        self.hidden_dim = hidden_size
+
+        self.w1 = ReplicatedLinear(self.hidden_dim,
+                                   self.ffn_dim,
+                                   bias=False,
+                                   quant_config=quant_config)
+        self.w2 = ReplicatedLinear(self.ffn_dim,
+                                   self.hidden_dim,
+                                   bias=False,
+                                   quant_config=quant_config)
+        self.w3 = ReplicatedLinear(self.hidden_dim,
+                                   self.ffn_dim,
+                                   bias=False,
+                                   quant_config=quant_config)
+
+        # TODO: Use vllm's SiluAndMul
+        self.act_fn = nn.SiLU()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        w1_out, _ = self.w1(hidden_states)
+        w1_out = self.act_fn(w1_out)
+        w3_out, _ = self.w3(hidden_states)
+        current_hidden_states = w1_out * w3_out
+        current_hidden_states, _ = self.w2(current_hidden_states)
+        return current_hidden_states
 
-        # FIXME(pcmoritz): Make this more general to support different
-        # quantization schemes
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
 
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
+class MixtralMoE(nn.Module):
 
-        # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(self.hidden_size,
+    def __init__(
+        self,
+        config: MixtralConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+        if self.tp_size > self.num_total_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {self.num_total_experts}.")
+        # Split experts equally between ranks
+        self.expert_indicies = np.array_split(range(
+            self.num_total_experts), self.tp_size)[self.rank].tolist()
+        if not self.expert_indicies:
+            raise ValueError(
+                f"Rank {self.rank} has no experts assigned to it.")
+
+        self.experts = nn.ModuleList([
+            MixtralMLP(self.num_total_experts,
+                       config.hidden_size,
+                       config.intermediate_size,
+                       quant_config=quant_config)
+            if idx in self.expert_indicies else None
+            for idx in range(self.num_total_experts)
+        ])
+        self.gate = ReplicatedLinear(config.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
-                                     params_dtype=self.params_dtype,
-                                     quant_config=None)
-
-        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        self.w13_weight = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        2 * self.intermediate_size,
-                        self.hidden_size,
-                        dtype=params_dtype))
-        self.w2_weight = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        self.hidden_size,
-                        self.intermediate_size,
-                        dtype=params_dtype))
-
-        set_weight_attrs(self.w13_weight, {
-            "weight_loader": self.weight_loader,
-        })
-        set_weight_attrs(self.w2_weight, {
-            "weight_loader": self.weight_loader,
-        })
-
-        # Used for fp8.
-        self.w13_scale = None
-        self.w2_scale = None
-        self.a13_scale = None
-        self.a2_scale = None
-
-        if self.use_fp8:
-            # WEIGHT_SCALE (for fp8)
-            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                     dtype=torch.float32),
-                                          requires_grad=False)
-            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                    dtype=torch.float32),
-                                         requires_grad=False)
-
-            # If loading fp8 checkpoint, pass the weight loaders.
-            # If loading an fp16 checkpoint, do not (we will quantize in
-            # process_weights_after_loading()
-            if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(self.w13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.w2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-
-            # ACT_SCALE (for fp8)
-            if quant_config.activation_scheme == "static":
-                if not quant_config.is_checkpoint_fp8_serialized:
-                    raise ValueError(
-                        "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8.")
-                self.a13_scale = nn.Parameter(torch.zeros(
-                    self.num_total_experts, dtype=torch.float32),
-                                              requires_grad=False)
-                self.a2_scale = nn.Parameter(torch.zeros(
-                    self.num_total_experts, dtype=torch.float32),
-                                             requires_grad=False)
-
-                set_weight_attrs(self.a13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.a2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-
-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
-                      weight_name: str, expert_id: int):
-        tp_rank = get_tensor_model_parallel_rank()
-        param_data = param.data
-        shard_size = self.intermediate_size
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("w1.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w3.weight"):
-            param_data[expert_id,
-                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w2.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name or "weight_scale" in weight_name:
-            param_data[expert_id] = loaded_weight
-
-    def process_weights_after_loading(self):
-        # Fp8 is the only case where we need to process after loading.
-        if not self.use_fp8:
-            return
-
-        # If checkpoint is fp16, quantize here.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(self.w13_weight.data,
-                                          dtype=torch.float8_e4m3fn)
-            w2_weight = torch.empty_like(self.w2_weight.data,
-                                         dtype=torch.float8_e4m3fn)
-            for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w13_weight.data[expert, :, :])
-                w2_weight[expert, :, :], self.w2_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w2_weight.data[expert, :, :])
-            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-        # If checkpoint is fp8 + static, cleanup act_scales.
-        # Since state_dict has an act_scale per expert but our kernels
-        # are passed one act_scale shared across all experts.
-        elif self.quant_config.activation_scheme == "static":
-            if self.a13_scale is None or self.a2_scale is None:
-                raise ValueError(
-                    "QuantConfig has static quantization, but found "
-                    "activation scales are None.")
-
-            if (not all_close_1d(self.a13_scale)
-                    or not all_close_1d(self.a2_scale)):
-                print_warning_once(
-                    "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. ")
-
-            self.a13_scale = nn.Parameter(self.a13_scale.max(),
-                                          requires_grad=False)
-            self.a2_scale = nn.Parameter(self.a2_scale.max(),
-                                         requires_grad=False)
+                                     quant_config=quant_config)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
-        hidden_states = hidden_states.view(-1, self.hidden_size)
-        # router_logits: (num_tokens, n_experts)
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w13_weight,
-                                        self.w2_weight,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=True,
-                                        inplace=True,
-                                        use_fp8=self.use_fp8,
-                                        w1_scale=self.w13_scale,
-                                        w2_scale=self.w2_scale,
-                                        a1_scale=self.a13_scale,
-                                        a2_scale=self.a2_scale)
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights,
+                                                       self.top_k,
+                                                       dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        final_hidden_states = None
+        for expert_idx in self.expert_indicies:
+            expert_layer = self.experts[expert_idx]
+            expert_mask = (selected_experts == expert_idx)
+            expert_weights = (routing_weights * expert_mask).sum(dim=-1,
+                                                                 keepdim=True)
+
+            current_hidden_states = expert_layer(hidden_states).mul_(
+                expert_weights)
+            if final_hidden_states is None:
+                final_hidden_states = current_hidden_states
+            else:
+                final_hidden_states.add_(current_hidden_states)
+
+        return tensor_model_parallel_all_reduce(final_hidden_states).view(
+            batch_size, sequence_length, hidden_dim)
+
+
 
 
 class MixtralAttention(nn.Module):
@@ -341,12 +261,8 @@ def __init__(
             rope_theta=rope_theta,
             cache_config=cache_config,
             quant_config=quant_config)
-        self.block_sparse_moe = MixtralMoE(
-            num_experts=config.num_local_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            quant_config=quant_config)
+        self.block_sparse_moe = MixtralMoE(config=config,
+                                           quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,

0 commit comments