 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 from itertools import islice
 from typing import Optional, Union

 import torch
 from torch import nn
 from transformers import MixtralConfig

 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
+                              get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -74,10 +76,32 @@ def __init__(self,
                  quant_config: Optional[QuantizationConfig] = None,
                  tp_size: Optional[int] = None,
                  dp_size: Optional[int] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 enable_eplb: bool = False):
         super().__init__()
         self.hidden_size = hidden_size

+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+
+        # Expert parallelism load balancing settings.
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        self.n_routed_experts = num_experts
+        self.n_logical_experts = num_experts
+        self.n_redundant_experts = (
+            parallel_config.eplb_config.num_redundant_experts)
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+        self.physical_expert_start = (self.ep_rank *
+                                      self.n_local_physical_experts)
+        self.physical_expert_end = (self.physical_expert_start +
+                                    self.n_local_physical_experts)
+
         # Gate always runs at half / full precision for now.

         self.gate = ReplicatedLinear(hidden_size,
@@ -97,7 +121,9 @@ def __init__(self,
                                 quant_config=quant_config,
                                 tp_size=tp_size,
                                 dp_size=dp_size,
-                                prefix=f"{prefix}.experts")
+                                prefix=f"{prefix}.experts",
+                                enable_eplb=self.enable_eplb,
+                                num_redundant_experts=self.n_redundant_experts)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
@@ -200,6 +226,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -221,7 +248,8 @@ def __init__(
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
-            prefix=f"{prefix}.block_sparse_moe")
+            prefix=f"{prefix}.block_sparse_moe",
+            enable_eplb=enable_eplb)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -262,6 +290,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        parallel_config = vllm_config.parallel_config

         self.config = config
         self.quant_config = quant_config
@@ -276,10 +305,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             org_num_embeddings=config.vocab_size,
         )

+        self.enable_eplb = parallel_config.enable_eplb
+        self.num_redundant_experts = (
+            parallel_config.eplb_config.num_redundant_experts)
+
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: MixtralDecoderLayer(
-                config, cache_config, quant_config=quant_config, prefix=prefix
+                config,
+                cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
+                enable_eplb=self.enable_eplb,
             ),
             prefix=f"{prefix}.layers")

@@ -325,7 +362,8 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
-            num_experts=self.config.num_local_experts)
+            num_experts=self.config.num_local_experts,
+            num_redundant_experts=self.num_redundant_experts)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
@@ -373,26 +411,40 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
+
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
+
+                    is_expert_weight = True
+                    name_mapped = name.replace(weight_name, param_name)
+
                     # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
-                    if ((name.endswith(".bias") or name.endswith("_bias"))
-                            and name not in params_dict):
+
+                    if ((name_mapped.endswith(".bias")
+                         or name_mapped.endswith("_bias"))
+                            and name_mapped not in params_dict):
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
+
+                    param = params_dict[name_mapped]
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(param,
+                                            loaded_weight,
+                                            name_mapped,
+                                            shard_id=shard_id,
+                                            expert_id=expert_id,
+                                            return_success=True)
+                    if success:
+                        name = name_mapped
+                        break
                 else:
+                    if is_expert_weight:
+                        continue
                     # Skip loading extra bias for GPTQ models.
                     if ((name.endswith(".bias") or name.endswith("_bias"))
                             and name not in params_dict):
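
The control flow in the hunk above is subtle because of the `for`/`else`. As a reading aid only, here is a stripped-down sketch of the same pattern; `try_load_expert` and `load_one_tensor` are illustrative stand-ins for FusedMoE's `weight_loader(..., return_success=True)`, not vLLM APIs:

```python
def load_one_tensor(name, expert_params_mapping, params_dict, try_load_expert):
    """Illustrative sketch of the expert branch in load_weights."""
    is_expert_weight = False
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        # The tensor matches *some* expert mapping, even if no local one takes it.
        is_expert_weight = True
        name_mapped = name.replace(weight_name, param_name)
        # The stand-in loader returns False when the expert lives on another rank.
        if try_load_expert(params_dict[name_mapped], name_mapped,
                           shard_id=shard_id, expert_id=expert_id):
            name = name_mapped
            break  # accepted by a local physical expert
    else:
        if is_expert_weight:
            # Expert tensor owned by another EP rank: skip it rather than
            # falling through to the dense-parameter loading path.
            return None
        ...  # ordinary (non-expert) parameter handling would go here
    return name
```
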
@@ -413,7 +465,8 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params


-class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
+                         MixtureOfExperts):
     fall_back_to_pt_during_load = False

     packed_modules_mapping = {
@@ -462,6 +515,67 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+        self.expert_weights = []
+        self.moe_layers: list[FusedMoE] = []
+        example_moe = None
+
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+            assert isinstance(layer, MixtralDecoderLayer)
+            if hasattr(layer, "block_sparse_moe") and isinstance(
+                    layer.block_sparse_moe, MixtralMoE):
+                example_moe = layer.block_sparse_moe
+                self.moe_layers.append(layer.block_sparse_moe.experts)
+
+        self.num_moe_layers = len(self.moe_layers)
+
+        if example_moe is None:
+            raise RuntimeError("No MixtralMoE layer found in model.layers.")
+
+        self.num_logical_experts = example_moe.n_logical_experts
+        self.num_physical_experts = example_moe.n_physical_experts
+        self.num_local_physical_experts = example_moe.n_local_physical_experts
+        self.num_routed_experts = example_moe.n_routed_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
+        self.num_expert_groups = 1
+        self.num_shared_experts = 0
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = (num_physical_experts -
+                                      self.num_logical_experts)
+        for layer in self.model.layers:
+            if hasattr(layer, "block_sparse_moe") and isinstance(
+                    layer.block_sparse_moe, MixtralMoE):
+                moe = layer.block_sparse_moe
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

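
For context, the new MixtralMoE fields encode a simple partition of experts across expert-parallel ranks: EPLB adds `num_redundant_experts` physical replicas on top of the logical experts, and each EP rank hosts a contiguous, equally sized slice. Below is a minimal standalone sketch of that arithmetic, assuming the physical expert count divides evenly by the EP world size; the function and dataclass names are illustrative only, not vLLM APIs:

```python
from dataclasses import dataclass


@dataclass
class ExpertPartition:
    n_physical_experts: int        # logical experts + redundant replicas
    n_local_physical_experts: int  # physical experts hosted on this EP rank
    physical_expert_start: int     # first physical expert id owned locally
    physical_expert_end: int       # one past the last local physical expert id


def partition_experts(num_logical: int, num_redundant: int,
                      ep_size: int, ep_rank: int) -> ExpertPartition:
    """Mirror of the arithmetic in MixtralMoE.__init__ (illustrative helper)."""
    n_physical = num_logical + num_redundant
    assert n_physical % ep_size == 0, "EPLB assumes an even split across EP ranks"
    n_local = n_physical // ep_size
    start = ep_rank * n_local
    return ExpertPartition(n_physical, n_local, start, start + n_local)


# Example: Mixtral 8x7B (8 logical experts), 8 redundant replicas, EP size 4.
# Each rank then hosts 4 physical experts; rank 2 owns ids [8, 12).
print(partition_experts(num_logical=8, num_redundant=8, ep_size=4, ep_rank=2))
```
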