11"""Minimal implementation of CLIPVisionModel intended to be only used
22within a vision language model."""
3- from typing import Iterable , List , Optional , Tuple , Union
3+ from typing import Iterable , List , Optional , Set , Tuple , Union
44
55import numpy as np
66import torch
@@ -483,14 +483,16 @@ def device(self):
483483
484484 # (TODO) Add prefix argument for filtering out weights to be loaded
485485 # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
486- def load_weights (self , weights : Iterable [Tuple [str , torch .Tensor ]]):
486+ def load_weights (self , weights : Iterable [Tuple [str ,
487+ torch .Tensor ]]) -> Set [str ]:
487488 stacked_params_mapping = [
488489 # (param_name, shard_name, shard_id)
489490 ("qkv_proj" , "q_proj" , "q" ),
490491 ("qkv_proj" , "k_proj" , "k" ),
491492 ("qkv_proj" , "v_proj" , "v" ),
492493 ] if self .shard_weight else []
493494 params_dict = dict (self .named_parameters ())
495+ loaded_params : Set [str ] = set ()
494496 layer_count = len (self .vision_model .encoder .layers )
495497
496498 for name , loaded_weight in weights :
@@ -508,8 +510,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
508510 for (param_name , weight_name , shard_id ) in stacked_params_mapping :
509511 if weight_name not in name :
510512 continue
513+ name = name .replace (weight_name , param_name )
511514
512- param = params_dict [name . replace ( weight_name , param_name ) ]
515+ param = params_dict [name ]
513516 weight_loader = param .weight_loader
514517 weight_loader (param , loaded_weight , shard_id )
515518 break
@@ -518,3 +521,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
518521 weight_loader = getattr (param , "weight_loader" ,
519522 default_weight_loader )
520523 weight_loader (param , loaded_weight )
524+ loaded_params .add (name )
525+ return loaded_params
0 commit comments