@@ -65,12 +65,6 @@ def __init__(self, wte=None) -> None:
         self.type_feature: str
         self.img_processor: CLIPVisionModel

-    def set_img_features(self, img_features: torch.FloatTensor) -> None:
-        self.img_features = img_features
-
-    def set_img_sizes(self, img_sizes: torch.LongTensor) -> None:
-        self.img_sizes = img_sizes
-
     def get_img_features(self,
                          img_embeds: torch.FloatTensor) -> torch.FloatTensor:
         LAYER_IDX = self.layer_idx
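This hunk drops the stateful set_img_features/set_img_sizes setters: image features and sizes now travel exclusively through forward arguments instead of being stashed on the module between calls. A minimal sketch of the stateless calling convention this enables, using a hypothetical TinyImageEmbed stand-in (illustrative only, not the real Phi-3V module):

import torch

# Hypothetical stand-in: shows the calling convention, not the real internals.
class TinyImageEmbed(torch.nn.Module):
    def forward(self, input_ids: torch.LongTensor,
                pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor) -> torch.FloatTensor:
        # Everything the pass needs arrives as an argument; nothing is
        # read from attributes left behind by an earlier set_* call.
        return pixel_values.flatten(1)

embed = TinyImageEmbed()
out = embed(torch.zeros(1, 8, dtype=torch.long),  # input_ids
            torch.zeros(1, 3, 336, 336),          # pixel_values
            torch.tensor([[336, 336]]))           # image_sizes

Keeping the module stateless also makes it safe to share one instance across concurrent requests, since no per-request data lingers on self.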
@@ -144,21 +138,16 @@ def __init__(self,
         self.layer_idx = config.img_processor.get('layer_idx', -2)
         self.type_feature = config.img_processor.get('type_feature', 'patch')

-    def forward(self,
-                input_ids: torch.LongTensor,
+    def forward(self, input_ids: torch.LongTensor,
                 pixel_values: torch.FloatTensor,
-                image_sizes=None) -> torch.FloatTensor:
+                image_sizes: torch.Tensor) -> torch.FloatTensor:
         """process and merge text embeddings with image embeddings."""

+        # (batch_size, max_num_crops, 3, height, width)
         img_embeds = pixel_values
-        img_sizes = image_sizes

-        if self.img_features is not None:
-            img_embeds = self.img_features.clone()
-            self.img_features = None
-
-        if self.img_sizes is not None:
-            img_sizes = self.img_sizes
+        # (batch_size, 2)
+        img_sizes = image_sizes

         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_shape[-1])
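The new shape comments pin down the contract: pixel_values holds the per-image batch of 336x336 crops and image_sizes the padded (height, width) of each image in pixels. A sketch of inputs matching those annotations (dummy values; the crop count 17 is just an example, not a value taken from this commit):

import torch

batch_size, max_num_crops = 2, 17  # e.g. 16 sub-crops + 1 global crop

# (batch_size, max_num_crops, 3, height, width)
pixel_values = torch.zeros(batch_size, max_num_crops, 3, 336, 336)

# (batch_size, 2) -- padded (height, width) per image, in pixels
image_sizes = torch.tensor([[672, 1008],
                            [336, 672]], dtype=torch.long)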
@@ -190,11 +179,8 @@ def forward(self,
         output_imgs = []
         output_len = []

-        if isinstance(img_sizes, torch.Tensor):
-            img_sizes.squeeze_(0)
-
         for _bs in range(bs):
-            h, w = img_sizes
+            h, w = img_sizes[_bs]
             h = h // 336
             w = w // 336
             B_ = h * w
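Indexing img_sizes[_bs] reads one (height, width) row per sample, so a batch may mix image sizes; the old path squeezed a single pair off the tensor, which only worked for batch size 1. The // 336 divisions then recover the crop grid. A worked check of that arithmetic (assuming sizes padded to multiples of 336, as the shape comments imply):

import torch

img_sizes = torch.tensor([[672, 1008],
                          [336, 672]], dtype=torch.long)

for _bs in range(img_sizes.shape[0]):
    h, w = img_sizes[_bs]        # this sample's padded (height, width)
    h, w = h // 336, w // 336    # crops along each axis
    B_ = h * w                   # total crops for this sample
    print(_bs, int(h), int(w), int(B_))
# 0 -> 2 x 3 grid, 6 crops; 1 -> 1 x 2 grid, 2 crops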