Skip to content

Commit cd11ccb

Browse files
committed
Fix regular Phi3V loading
1 parent 11fbbd4 commit cd11ccb

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

vllm/model_executor/models/phi3v.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,4 +710,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
710710
})
711711

712712
loader = AutoWeightsLoader(self)
713-
loader.load_weights(weights, mapper=hf_to_vllm_mapper)
713+
autoloaded_weights = loader.load_weights(weights,
714+
mapper=hf_to_vllm_mapper)
715+
716+
# The HF config doesn't specify whether the embedding weights are tied,
717+
# so we detect tying by checking whether `embed_tokens` was loaded
718+
if "embed_tokens" not in autoloaded_weights:
719+
self.embed_tokens = self.language_model.model.embed_tokens

vllm/model_executor/models/utils.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def _load_param(
124124
base_prefix: str,
125125
param: nn.Parameter,
126126
weights: Iterable[Tuple[str, torch.Tensor]],
127-
) -> None:
127+
) -> Iterable[str]:
128128
for weight_name, weight_data in weights:
129129
weight_qualname = self._get_qualname(base_prefix, weight_name)
130130

@@ -143,12 +143,14 @@ def _load_param(
143143
default_weight_loader)
144144
weight_loader(param, weight_data)
145145

146+
yield weight_qualname
147+
146148
def _load_module(
147149
self,
148150
base_prefix: str,
149151
module: nn.Module,
150152
weights: Iterable[Tuple[str, torch.Tensor]],
151-
) -> None:
153+
) -> Iterable[str]:
152154
if isinstance(module, PPMissingLayer):
153155
return
154156

@@ -170,11 +172,13 @@ def _load_module(
170172
continue
171173

172174
if child_prefix in child_modules:
173-
self._load_module(prefix, child_modules[child_prefix],
174-
child_weights)
175+
yield from self._load_module(prefix,
176+
child_modules[child_prefix],
177+
child_weights)
175178
elif child_prefix in child_params:
176-
self._load_param(prefix, child_params[child_prefix],
177-
child_weights)
179+
yield from self._load_param(prefix,
180+
child_params[child_prefix],
181+
child_weights)
178182
else:
179183
if not self._can_ignore_unexpected(prefix):
180184
msg = (f"There is no module or parameter named '{prefix}' "
@@ -186,11 +190,12 @@ def load_weights(
186190
weights: Iterable[Tuple[str, torch.Tensor]],
187191
*,
188192
mapper: Optional[WeightsMapper] = None,
189-
) -> None:
193+
) -> List[str]:
190194
if mapper is not None:
191195
weights = mapper.apply(weights)
192196

193-
self._load_module("", self.module, weights)
197+
autoloaded_weights = list(self._load_module("", self.module, weights))
198+
return autoloaded_weights
194199

195200

196201
def init_vllm_registered_model(

0 commit comments

Comments
 (0)