diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 25b8b69e081b..4ec2b683fc33 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -505,30 +505,47 @@ def tensor_parallel(self):
         Apply the model's tensor parallelization plan.
         Currently only supports linear layers.
         """
-        tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}
+        # Look for tp plans in all of the PreTrainedModels found in self.model
+        is_pretrained_model = lambda m: isinstance(m, PreTrainedModel)
+        supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None
+        pretrained_models = filter(is_pretrained_model, self.model.modules())
+        models_with_tp_plan = filter(supports_tp_plan, pretrained_models)
 
-        if not tp_plan and self.tp_size > 1:
+        if not any(models_with_tp_plan) and self.tp_size > 1:
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel yet!")
 
-        # Some weight loaders expect linear layers to inherit from vLLM's
-        # LinearBase class, so we set a default style which causes any
-        # unspecified linear layers to be replaced with ReplicatedLinear
-        tp_plan[".*"] = "replicate"
-
-        def _tensor_parallel(module: nn.Module, prefix: str = ""):
+        def _tensor_parallel(module: nn.Module,
+                             prefix: str = "",
+                             tp_plan=None):
+            tp_plan = tp_plan or {}
+
+            # If the current module is a PreTrainedModel, set the tp_plan for
+            # all of its children
+            if isinstance(module, PreTrainedModel):
+                tp_plan = module.config.base_model_tp_plan or {}
+                tp_plan = {
+                    maybe_prefix(prefix, k): v
+                    for k, v in tp_plan.items()
+                }
+
+            # Some weight loaders expect linear layers to inherit from vLLM's
+            # LinearBase class, so we set a default style which causes any
+            # unspecified linear layers to be replaced with ReplicatedLinear
             for child_name, child_module in module.named_children():
                 qual_name = maybe_prefix(prefix, child_name)
-                for pattern, style in tp_plan.items():
-                    if re.match(pattern, qual_name) and isinstance(
-                            child_module, nn.Linear):
-                        new_module = replace_linear_class(
-                            child_module, style, self.quant_config)
-                        setattr(module, child_name, new_module)
-                        log_replacement(qual_name, child_module, new_module)
-                        break
+                if isinstance(child_module, nn.Linear):
+                    generator = (p for p in tp_plan if re.match(p, qual_name))
+                    pattern = next(generator, None)
+                    style = tp_plan.get(pattern, "replicate")
+                    new_module = replace_linear_class(child_module, style,
+                                                      self.quant_config)
+                    setattr(module, child_name, new_module)
+                    log_replacement(qual_name, child_module, new_module)
                 else:
-                    _tensor_parallel(child_module, prefix=qual_name)
+                    _tensor_parallel(child_module,
+                                     prefix=qual_name,
+                                     tp_plan=tp_plan)
 
         _tensor_parallel(self.model)
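
For reference, a minimal standalone sketch (not part of the diff) of the fallback matching that the rewritten loop relies on: the tp plan maps regex patterns over qualified module names to parallel styles, and a linear layer that matches no pattern still gets replaced, but with the "replicate" style. The names toy_plan and resolve_style and the plan contents are illustrative only, not vLLM or Transformers APIs.

import re

# Hypothetical plan shaped like a Transformers base_model_tp_plan:
# regex over qualified module names -> parallel style.
toy_plan = {
    r"layers\.\d+\.self_attn\.q_proj": "colwise",
    r"layers\.\d+\.self_attn\.o_proj": "rowwise",
}

def resolve_style(tp_plan: dict, qual_name: str) -> str:
    """Return the style of the first matching pattern, else "replicate".

    Mirrors the generator/next/get idiom in _tensor_parallel: when no
    pattern matches, next() yields None, which is not a key in the plan,
    so .get() falls back to the default style.
    """
    generator = (p for p in tp_plan if re.match(p, qual_name))
    pattern = next(generator, None)
    return tp_plan.get(pattern, "replicate")

print(resolve_style(toy_plan, "layers.0.self_attn.q_proj"))  # colwise
print(resolve_style(toy_plan, "layers.0.mlp.gate_proj"))     # replicate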