|
23 | 23 | """Inference-only Mixtral model.""" |
24 | 24 | from typing import Iterable, List, Optional, Tuple |
25 | 25 |
|
| 26 | +import numpy as np |
| 27 | +import torch.nn.functional as F |
26 | 28 | import torch |
27 | 29 | from torch import nn |
28 | 30 | from transformers import MixtralConfig |
|
37 | 39 | from vllm.model_executor.layers.layernorm import RMSNorm |
38 | 40 | from vllm.model_executor.layers.linear import (QKVParallelLinear, |
39 | 41 | ReplicatedLinear, |
| 42 | + LinearMethodBase, |
40 | 43 | RowParallelLinear) |
41 | 44 | from vllm.model_executor.layers.logits_processor import LogitsProcessor |
42 | 45 | from vllm.model_executor.layers.quantization.base_config import ( |
|
53 | 56 | from vllm.utils import print_warning_once |
54 | 57 |
|
55 | 58 |
|
56 | | -class MixtralMoE(nn.Module): |
57 | | - """A tensor-parallel MoE implementation for Mixtral that shards each expert |
58 | | - across all ranks. |
59 | | -
|
60 | | - Each expert's weights are sharded across all ranks and a fused MoE |
61 | | - kernel is used for the forward pass, and finally we reduce the outputs |
62 | | - across ranks. |
63 | | - """ |
| 59 | +class MixtralMLP(nn.Module): |
64 | 60 |
|
65 | 61 | def __init__( |
66 | 62 | self, |
67 | 63 | num_experts: int, |
68 | | - top_k: int, |
69 | 64 | hidden_size: int, |
70 | 65 | intermediate_size: int, |
71 | | - params_dtype: Optional[torch.dtype] = None, |
72 | | - tp_size: Optional[int] = None, |
73 | 66 | quant_config: Optional[QuantizationConfig] = None, |
74 | | - ): |
| 67 | + ) -> None: |
75 | 68 | super().__init__() |
76 | | - self.tp_size = tp_size or get_tensor_model_parallel_world_size() |
77 | | - self.num_total_experts = num_experts |
78 | | - self.top_k = top_k |
79 | | - self.hidden_size = hidden_size |
80 | | - self.intermediate_size = intermediate_size // self.tp_size |
81 | | - self.quant_config = quant_config |
| 69 | + self.num_experts = num_experts |
| 70 | + self.ffn_dim = intermediate_size |
| 71 | + self.hidden_dim = hidden_size |
| 72 | + |
| 73 | + self.w1 = ReplicatedLinear(self.hidden_dim, |
| 74 | + self.ffn_dim, |
| 75 | + bias=False, |
| 76 | + quant_config=quant_config) |
| 77 | + self.w2 = ReplicatedLinear(self.ffn_dim, |
| 78 | + self.hidden_dim, |
| 79 | + bias=False, |
| 80 | + quant_config=quant_config) |
| 81 | + self.w3 = ReplicatedLinear(self.hidden_dim, |
| 82 | + self.ffn_dim, |
| 83 | + bias=False, |
| 84 | + quant_config=quant_config) |
| 85 | + |
| 86 | + # TODO: Use vllm's SiluAndMul |
| 87 | + self.act_fn = nn.SiLU() |
| 88 | + |
| 89 | + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| 90 | + w1_out, _ = self.w1(hidden_states) |
| 91 | + w1_out = self.act_fn(w1_out) |
| 92 | + w3_out, _ = self.w3(hidden_states) |
| 93 | + current_hidden_states = w1_out * w3_out |
| 94 | + current_hidden_states, _ = self.w2(current_hidden_states) |
| 95 | + return current_hidden_states |
82 | 96 |
|
83 | | - # FIXME(pcmoritz): Make this more general to support different |
84 | | - # quantization schemes |
85 | | - self.use_fp8 = isinstance(quant_config, Fp8Config) |
86 | 97 |
|
87 | | - if params_dtype is None: |
88 | | - params_dtype = torch.get_default_dtype() |
89 | | - self.params_dtype = params_dtype |
| 98 | +class MixtralMoE(nn.Module): |
90 | 99 |
|
91 | | - # Gate always runs at half / full precision for now. |
92 | | - self.gate = ReplicatedLinear(self.hidden_size, |
| 100 | + def __init__( |
| 101 | + self, |
| 102 | + config: MixtralConfig, |
| 103 | + quant_config: Optional[QuantizationConfig] = None, |
| 104 | + ): |
| 105 | + super().__init__() |
| 106 | + self.config = config |
| 107 | + self.rank = get_tensor_model_parallel_rank() |
| 108 | + self.tp_size = get_tensor_model_parallel_world_size() |
| 109 | + self.num_total_experts = config.num_local_experts |
| 110 | + self.top_k = config.num_experts_per_tok |
| 111 | + if self.tp_size > self.num_total_experts: |
| 112 | + raise ValueError( |
| 113 | + f"Tensor parallel size {self.tp_size} is greater than " |
| 114 | + f"the number of experts {self.num_total_experts}.") |
| 115 | + # Split the experts evenly across the tensor-parallel ranks |
| 116 | + self.expert_indices = np.array_split(range( |
| 117 | + self.num_total_experts), self.tp_size)[self.rank].tolist() |
| 118 | + if not self.expert_indices: |
| 119 | + raise ValueError( |
| 120 | + f"Rank {self.rank} has no experts assigned to it.") |
| 121 | + |
| 122 | + self.experts = nn.ModuleList([ |
| 123 | + MixtralMLP(self.num_total_experts, |
| 124 | + config.hidden_size, |
| 125 | + config.intermediate_size, |
| 126 | + quant_config=quant_config) |
| 127 | + if idx in self.expert_indicies else None |
| 128 | + for idx in range(self.num_total_experts) |
| 129 | + ]) |
| 130 | + self.gate = ReplicatedLinear(config.hidden_size, |
93 | 131 | self.num_total_experts, |
94 | 132 | bias=False, |
95 | | - params_dtype=self.params_dtype, |
96 | | - quant_config=None) |
97 | | - |
98 | | - if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: |
99 | | - params_dtype = torch.float8_e4m3fn |
100 | | - |
101 | | - self.w13_weight = nn.Parameter( |
102 | | - torch.empty(self.num_total_experts, |
103 | | - 2 * self.intermediate_size, |
104 | | - self.hidden_size, |
105 | | - dtype=params_dtype)) |
106 | | - self.w2_weight = nn.Parameter( |
107 | | - torch.empty(self.num_total_experts, |
108 | | - self.hidden_size, |
109 | | - self.intermediate_size, |
110 | | - dtype=params_dtype)) |
111 | | - |
112 | | - set_weight_attrs(self.w13_weight, { |
113 | | - "weight_loader": self.weight_loader, |
114 | | - }) |
115 | | - set_weight_attrs(self.w2_weight, { |
116 | | - "weight_loader": self.weight_loader, |
117 | | - }) |
118 | | - |
119 | | - # Used for fp8. |
120 | | - self.w13_scale = None |
121 | | - self.w2_scale = None |
122 | | - self.a13_scale = None |
123 | | - self.a2_scale = None |
124 | | - |
125 | | - if self.use_fp8: |
126 | | - # WEIGHT_SCALE (for fp8) |
127 | | - self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts, |
128 | | - dtype=torch.float32), |
129 | | - requires_grad=False) |
130 | | - self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts, |
131 | | - dtype=torch.float32), |
132 | | - requires_grad=False) |
133 | | - |
134 | | - # If loading fp8 checkpoint, pass the weight loaders. |
135 | | - # If loading an fp16 checkpoint, do not (we will quantize in |
136 | | - # process_weights_after_loading() |
137 | | - if quant_config.is_checkpoint_fp8_serialized: |
138 | | - set_weight_attrs(self.w13_scale, { |
139 | | - "weight_loader": self.weight_loader, |
140 | | - }) |
141 | | - set_weight_attrs(self.w2_scale, { |
142 | | - "weight_loader": self.weight_loader, |
143 | | - }) |
144 | | - |
145 | | - # ACT_SCALE (for fp8) |
146 | | - if quant_config.activation_scheme == "static": |
147 | | - if not quant_config.is_checkpoint_fp8_serialized: |
148 | | - raise ValueError( |
149 | | - "Found static activation scheme for checkpoint that " |
150 | | - "was not serialized fp8.") |
151 | | - self.a13_scale = nn.Parameter(torch.zeros( |
152 | | - self.num_total_experts, dtype=torch.float32), |
153 | | - requires_grad=False) |
154 | | - self.a2_scale = nn.Parameter(torch.zeros( |
155 | | - self.num_total_experts, dtype=torch.float32), |
156 | | - requires_grad=False) |
157 | | - |
158 | | - set_weight_attrs(self.a13_scale, { |
159 | | - "weight_loader": self.weight_loader, |
160 | | - }) |
161 | | - set_weight_attrs(self.a2_scale, { |
162 | | - "weight_loader": self.weight_loader, |
163 | | - }) |
164 | | - |
165 | | - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, |
166 | | - weight_name: str, expert_id: int): |
167 | | - tp_rank = get_tensor_model_parallel_rank() |
168 | | - param_data = param.data |
169 | | - shard_size = self.intermediate_size |
170 | | - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) |
171 | | - if weight_name.endswith("w1.weight"): |
172 | | - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] |
173 | | - if weight_name.endswith("w3.weight"): |
174 | | - param_data[expert_id, |
175 | | - shard_size:2 * shard_size, :] = loaded_weight[shard, :] |
176 | | - if weight_name.endswith("w2.weight"): |
177 | | - param_data[expert_id, :, :] = loaded_weight[:, shard] |
178 | | - if "act_scale" in weight_name or "weight_scale" in weight_name: |
179 | | - param_data[expert_id] = loaded_weight |
180 | | - |
181 | | - def process_weights_after_loading(self): |
182 | | - # Fp8 is the only case where we need to process after loading. |
183 | | - if not self.use_fp8: |
184 | | - return |
185 | | - |
186 | | - # If checkpoint is fp16, quantize here. |
187 | | - if not self.quant_config.is_checkpoint_fp8_serialized: |
188 | | - w13_weight = torch.empty_like(self.w13_weight.data, |
189 | | - dtype=torch.float8_e4m3fn) |
190 | | - w2_weight = torch.empty_like(self.w2_weight.data, |
191 | | - dtype=torch.float8_e4m3fn) |
192 | | - for expert in range(self.num_total_experts): |
193 | | - w13_weight[expert, :, :], self.w13_scale[ |
194 | | - expert] = ops.scaled_fp8_quant( |
195 | | - self.w13_weight.data[expert, :, :]) |
196 | | - w2_weight[expert, :, :], self.w2_scale[ |
197 | | - expert] = ops.scaled_fp8_quant( |
198 | | - self.w2_weight.data[expert, :, :]) |
199 | | - self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) |
200 | | - self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) |
201 | | - |
202 | | - # If checkpoint is fp8 + static, cleanup act_scales. |
203 | | - # Since state_dict has an act_scale per expert but our kernels |
204 | | - # are passed one act_scale shared across all experts. |
205 | | - elif self.quant_config.activation_scheme == "static": |
206 | | - if self.a13_scale is None or self.a2_scale is None: |
207 | | - raise ValueError( |
208 | | - "QuantConfig has static quantization, but found " |
209 | | - "activation scales are None.") |
210 | | - |
211 | | - if (not all_close_1d(self.a13_scale) |
212 | | - or not all_close_1d(self.a2_scale)): |
213 | | - print_warning_once( |
214 | | - "Found act_scales that are not equal for fp8 MoE layer. " |
215 | | - "Using the maximum across experts for each layer. ") |
216 | | - |
217 | | - self.a13_scale = nn.Parameter(self.a13_scale.max(), |
218 | | - requires_grad=False) |
219 | | - self.a2_scale = nn.Parameter(self.a2_scale.max(), |
220 | | - requires_grad=False) |
| 133 | + quant_config=quant_config) |
221 | 134 |
|
222 | 135 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
223 | | - num_tokens, hidden_size = hidden_states.shape |
224 | | - hidden_states = hidden_states.view(-1, self.hidden_size) |
225 | | - # router_logits: (num_tokens, n_experts) |
| 136 | + num_tokens, hidden_dim = hidden_states.shape |
| 137 | + hidden_states = hidden_states.view(-1, hidden_dim) |
| 138 | + # router_logits: (num_tokens, n_experts) |
226 | 139 | router_logits, _ = self.gate(hidden_states) |
227 | | - final_hidden_states = fused_moe(hidden_states, |
228 | | - self.w13_weight, |
229 | | - self.w2_weight, |
230 | | - router_logits, |
231 | | - self.top_k, |
232 | | - renormalize=True, |
233 | | - inplace=True, |
234 | | - use_fp8=self.use_fp8, |
235 | | - w1_scale=self.w13_scale, |
236 | | - w2_scale=self.w2_scale, |
237 | | - a1_scale=self.a13_scale, |
238 | | - a2_scale=self.a2_scale) |
239 | | - |
240 | | - if self.tp_size > 1: |
241 | | - final_hidden_states = tensor_model_parallel_all_reduce( |
242 | | - final_hidden_states) |
243 | | - |
244 | | - return final_hidden_states.view(num_tokens, hidden_size) |
| 140 | + |
| 141 | + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) |
| 142 | + routing_weights, selected_experts = torch.topk(routing_weights, |
| 143 | + self.top_k, |
| 144 | + dim=-1) |
| 145 | + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) |
| 146 | + |
| 147 | + final_hidden_states = None |
| 148 | + for expert_idx in self.expert_indices: |
| 149 | + expert_layer = self.experts[expert_idx] |
| 150 | + expert_mask = (selected_experts == expert_idx) |
| 151 | + expert_weights = (routing_weights * expert_mask).sum(dim=-1, |
| 152 | + keepdim=True) |
| 153 | + |
| 154 | + current_hidden_states = expert_layer(hidden_states).mul_( |
| 155 | + expert_weights) |
| 156 | + if final_hidden_states is None: |
| 157 | + final_hidden_states = current_hidden_states |
| 158 | + else: |
| 159 | + final_hidden_states.add_(current_hidden_states) |
| 160 | + |
| 161 | + return tensor_model_parallel_all_reduce(final_hidden_states).view( |
| 162 | + num_tokens, hidden_dim) |
| 163 | + |
| 164 | + |
245 | 165 |
|
246 | 166 |
|
247 | 167 | class MixtralAttention(nn.Module): |
@@ -341,12 +261,8 @@ def __init__( |
341 | 261 | rope_theta=rope_theta, |
342 | 262 | cache_config=cache_config, |
343 | 263 | quant_config=quant_config) |
344 | | - self.block_sparse_moe = MixtralMoE( |
345 | | - num_experts=config.num_local_experts, |
346 | | - top_k=config.num_experts_per_tok, |
347 | | - hidden_size=config.hidden_size, |
348 | | - intermediate_size=config.intermediate_size, |
349 | | - quant_config=quant_config) |
| 264 | + self.block_sparse_moe = MixtralMoE(config=config, |
| 265 | + quant_config=quant_config) |
350 | 266 | self.input_layernorm = RMSNorm(config.hidden_size, |
351 | 267 | eps=config.rms_norm_eps) |
352 | 268 | self.post_attention_layernorm = RMSNorm(config.hidden_size, |
|
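For reference, here is a minimal, self-contained sketch (not part of the diff) of how `np.array_split` partitions the expert indices across tensor-parallel ranks in the new `MixtralMoE.__init__`. The eight experts match Mixtral-8x7B's `num_local_experts`; the tensor-parallel size of 2 is only an illustrative assumption.

```python
import numpy as np

num_total_experts = 8  # Mixtral-8x7B exposes 8 local experts per layer
tp_size = 2            # illustrative tensor-parallel world size

# Each rank keeps a contiguous slice of expert indices; experts outside the
# slice are stored as None entries in that rank's nn.ModuleList.
for rank in range(tp_size):
    expert_indices = np.array_split(range(num_total_experts), tp_size)[rank].tolist()
    print(f"rank {rank}: experts {expert_indices}")

# rank 0: experts [0, 1, 2, 3]
# rank 1: experts [4, 5, 6, 7]
```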
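Similarly, a rough sketch of the routing arithmetic the new `MixtralMoE.forward` performs on each rank: softmax over the router logits, top-k selection, renormalization, then a masked weighted sum over only the locally owned experts. The per-expert MLPs are stood in for by random projection matrices, the shapes are illustrative, and the final cross-rank `tensor_model_parallel_all_reduce` is only indicated by a comment.

```python
import torch
import torch.nn.functional as F

num_tokens, hidden_dim, num_experts, top_k = 4, 16, 8, 2
hidden_states = torch.randn(num_tokens, hidden_dim)
router_logits = torch.randn(num_tokens, num_experts)

# Stand-ins for the per-expert MLPs: one random projection per expert.
# A real rank only materializes the experts it owns.
expert_mlps = [torch.randn(hidden_dim, hidden_dim) for _ in range(num_experts)]
local_expert_indices = [0, 1, 2, 3]  # this rank's share (see the split above)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

final_hidden_states = torch.zeros(num_tokens, hidden_dim)
for expert_idx in local_expert_indices:
    # Tokens whose top-k set does not include this expert get weight 0, so
    # running the expert densely over every token is numerically harmless.
    expert_mask = (selected_experts == expert_idx)
    expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
    final_hidden_states += (hidden_states @ expert_mlps[expert_idx]) * expert_weights

# In the model, tensor_model_parallel_all_reduce then sums these partial
# results across ranks so every token recovers its full top-k mixture.
print(final_hidden_states.shape)  # torch.Size([4, 16])
```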