@@ -124,7 +124,7 @@ def _load_param(
124124 base_prefix : str ,
125125 param : nn .Parameter ,
126126 weights : Iterable [Tuple [str , torch .Tensor ]],
127- ) -> None :
127+ ) -> Iterable [ str ] :
128128 for weight_name , weight_data in weights :
129129 weight_qualname = self ._get_qualname (base_prefix , weight_name )
130130
@@ -143,12 +143,14 @@ def _load_param(
143143 default_weight_loader )
144144 weight_loader (param , weight_data )
145145
146+ yield weight_qualname
147+
146148 def _load_module (
147149 self ,
148150 base_prefix : str ,
149151 module : nn .Module ,
150152 weights : Iterable [Tuple [str , torch .Tensor ]],
151- ) -> None :
153+ ) -> Iterable [ str ] :
152154 if isinstance (module , PPMissingLayer ):
153155 return
154156
@@ -170,11 +172,13 @@ def _load_module(
170172 continue
171173
172174 if child_prefix in child_modules :
173- self ._load_module (prefix , child_modules [child_prefix ],
174- child_weights )
175+ yield from self ._load_module (prefix ,
176+ child_modules [child_prefix ],
177+ child_weights )
175178 elif child_prefix in child_params :
176- self ._load_param (prefix , child_params [child_prefix ],
177- child_weights )
179+ yield from self ._load_param (prefix ,
180+ child_params [child_prefix ],
181+ child_weights )
178182 else :
179183 if not self ._can_ignore_unexpected (prefix ):
180184 msg = (f"There is no module or parameter named '{ prefix } ' "
@@ -186,11 +190,12 @@ def load_weights(
186190 weights : Iterable [Tuple [str , torch .Tensor ]],
187191 * ,
188192 mapper : Optional [WeightsMapper ] = None ,
189- ) -> None :
193+ ) -> List [ str ] :
190194 if mapper is not None :
191195 weights = mapper .apply (weights )
192196
193- self ._load_module ("" , self .module , weights )
197+ autoloaded_weights = list (self ._load_module ("" , self .module , weights ))
198+ return autoloaded_weights
194199
195200
196201def init_vllm_registered_model (
0 commit comments