237 changes: 118 additions & 119 deletions vllm/model_executor/models/phi3v.py
@@ -43,6 +43,7 @@
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings

logger = init_logger(__name__)

@@ -71,9 +72,8 @@

class Phi3ImageEmbeddingBase(nn.Module):

def __init__(self, wte=None) -> None:
def __init__(self) -> None:
super().__init__()
self.wte = wte
self.layer_idx: int
self.type_feature: str
self.img_processor: CLIPVisionModel
@@ -100,10 +100,9 @@ def get_img_features(self,
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
"""Phi3 Image embedding with HD transform."""

def __init__(self, config: PretrainedConfig, wte=None) -> None:
super().__init__(wte)
def __init__(self, config: PretrainedConfig) -> None:
super().__init__()

self.image_token_id = _IMAGE_TOKEN_ID
# n_embed or hidden_size
hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size
@@ -149,118 +148,115 @@ def __init__(self, config: PretrainedConfig, wte=None) -> None:
nn.Linear(dim_projection, dim_projection)])
self.img_projection = nn.Sequential(*layers)

self.vocab_size = config.vocab_size
self.type_feature = config.img_processor.get('type_feature', 'patch')

def forward(self, input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
def forward(self, pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor) -> torch.FloatTensor:
"""process and merge text embeddings with image embeddings."""

# (batch_size, max_num_crops, 3, height, width)
img_embeds = pixel_values

# (batch_size, 2)
img_sizes = image_sizes

input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])

positions = torch.nonzero(input_ids == self.image_token_id)

select = False

target_dtype = self.img_projection[0].bias.dtype

if len(positions.tolist()) > 0:
# if self.use_hd_transform and img_sizes:
# img_embeds: (num_images, max_num_crops, 3, H, W)
# img_sizes: (num_images, 2).view(1, -1)

bs = img_embeds.shape[0]
# Nx(HW)xC
img_features = self.get_img_features(img_embeds.flatten(0, 1))
base_feat_height = base_feat_width = int(
img_features.shape[1]**0.5)

# bs x max_num_crops x (24x24) x C
img_features = img_features.view(
bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
C = self.image_dim_out
H = base_feat_height

output_imgs = []
output_len = []

for _bs in range(bs):
h, w = img_sizes[_bs]
h = h // 336
w = w // 336
B_ = h * w

# 1 x (24x24) x 1024
global_img_feature = img_features[_bs, :1]

# 1 x 12 x 12 x 4096
glb_img = global_img_feature \
.reshape(1, H // 2, 2, H // 2, 2,C) \
.permute(0, 1, 3, 2, 4, 5) \
.reshape(1, H // 2, H // 2, 4 * C)
temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1)

# 1 x 156 x 4096
glb_img = torch.cat([glb_img, temp_glb_GN],
dim=2).reshape(1, -1, 4 * C)

# (max_num_crops-1) x (12x12) x C
sub_img = img_features[_bs, 1:]
# 16x574x1024
# get rid of padding sub_img
sub_img = sub_img[:B_]

sub_img = sub_img.reshape(B_, H // 2, 2, H // 2, 2, C) \
.permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C)
sub_img = sub_img.reshape(1, h, w, 12, 12, -1) \
.permute(0, 1, 3, 2, 4, 5) \
.reshape(1, h * 12, w * 12, 4 * C)
temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
sub_img = torch.cat([sub_img, temp_sub_GN],
dim=2).reshape(1, -1, 4 * C)
# (1, num_img_tokens, 1024*4)

# glb + sub
if self.hd_transform_order == 'glb_sub':
output_imgs.append(
torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
elif self.hd_transform_order == 'sub_glb':
output_imgs.append(
torch.cat([sub_img, self.glb_GN, glb_img], dim=1))

temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
output_len.append(temp_len)

num_img_tokens = output_len
img_set_tensor = []
for _output_img in output_imgs:
img_feature_proj = self.img_projection(
_output_img.to(target_dtype))
img_set_tensor.append(img_feature_proj)
select = True

input_ids.clamp_min_(0).clamp_max_(self.vocab_size)

hidden_states = self.wte(input_ids)

if select:
idx = 0
for i, cnt in enumerate(num_img_tokens):
hidden_states[positions[idx, 0],
positions[idx, 1]:positions[idx, 1] +
cnt] = (img_set_tensor[i].to(
hidden_states.dtype))
idx += cnt

return hidden_states.squeeze(0)
"""
process image and return vision embeddings.

pixel_values: (num_images, num_crops, c, h, w)
output: (num_images, num_img_tokens, hidden_size)
"""
num_images, num_crops, c, h, w = pixel_values.shape
pixel_values = pixel_values.flatten(0, 1)
img_features = self.get_img_features(pixel_values)
img_features = img_features.reshape(num_images, num_crops, -1,
self.image_dim_out)
image_features_proj = self.hd_feature_transform(
img_features, image_sizes)
return image_features_proj

def hd_feature_transform(self, image_features, image_sizes):
"""
image_features: (num_images, num_crops+1, 24*24, 1024)
"""
assert (
self.hd_transform_order == 'sub_glb'
), f'hd_transform_order `{self.hd_transform_order}` not implemented'
if isinstance(self.img_projection, nn.Sequential):
target_device = self.img_projection[0].bias.device
target_dtype = self.img_projection[0].bias.dtype
else: # It's a single nn.Linear layer
target_device = self.img_projection.bias.device
target_dtype = self.img_projection.bias.dtype

global_image_features = image_features[:,
0] # (num_images, 24*24, 1024)
# global feature can be viewed as a special HD case with num_crops 1x1
global_image_features_hd = self.reshape_hd_patches_2x2merge(
global_image_features, 1, 1)
global_image_features_hd_newline = self.add_image_newline(
global_image_features_hd)

all_image_embeddings = []
# need a for loop to process each image because of different image sizes
# (patch arrangement is different for each image)
for i, img_size in enumerate(image_sizes):
h, w = img_size
h_crop = h // 336
w_crop = w // 336
num_crops = h_crop * w_crop

# NOTE: real num_crops is padded
# (num_crops, 24*24, 1024)
sub_image_features = image_features[i, 1:1 + num_crops]
sub_image_features_hd = self.reshape_hd_patches_2x2merge(
sub_image_features, h_crop, w_crop)
sub_image_features_hd_newline = self.add_image_newline(
sub_image_features_hd)

# [sub features, separator, global features]
all_image_embeddings.append(
torch.cat([
sub_image_features_hd_newline.squeeze(
0), # (h_crop*12*(w_crop*12+1), 4096)
self.glb_GN.squeeze(0),
global_image_features_hd_newline[i],
]))

image_features_proj = self.img_projection(
torch.stack(all_image_embeddings).to(target_device, target_dtype)
) # (num_images, (h_crop*12*(w_crop*12+1)+1), hidden_size)

return image_features_proj

def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
"""
image_features: (num_images*num_crops, 24*24, 1024)
output: (num_images, h_crop*12, w_crop*12, 4096)
where h_crop*w_crop == num_crops
"""
N, L, C = image_features.shape
assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
num_images = N // (h_crop * w_crop)
H = int(L**0.5)
image_features_hd = (
image_features.reshape(N, H, H, C) # N, 24, 24, 1024
.reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024
.permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024
.reshape(N, -1, 4 * C) # N, 144, 4096
.reshape(num_images, h_crop, w_crop, H // 2, H // 2,
-1) # n_img, h_crop, w_crop, 12, 12, 4096
.permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096
.reshape(num_images, h_crop * H // 2, w_crop * H // 2,
4 * C) # n_img, h_crop*12, w_crop*12, 4096
)
return image_features_hd
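
As an illustrative aside (not part of the file's diff), the shape bookkeeping in reshape_hd_patches_2x2merge can be checked with plain torch on dummy data; a minimal sketch, assuming a single image split into a 2x1 crop grid:

    import torch

    h_crop, w_crop, C = 2, 1, 1024
    feats = torch.randn(h_crop * w_crop, 576, C)       # (num_crops, 24*24, 1024)
    hd = (feats.reshape(-1, 24, 24, C)                 # crops as 24x24 patch grids
          .reshape(-1, 12, 2, 12, 2, C)                # split each grid into 2x2 blocks
          .permute(0, 1, 3, 2, 4, 5)
          .reshape(-1, 144, 4 * C)                     # merge each 2x2 block -> 4096 channels
          .reshape(1, h_crop, w_crop, 12, 12, 4 * C)   # arrange crops spatially per image
          .permute(0, 1, 3, 2, 4, 5)
          .reshape(1, h_crop * 12, w_crop * 12, 4 * C))
    assert hd.shape == (1, 24, 12, 4096)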

def add_image_newline(self, image_features_hd):
"""
image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
"""
num_images, h, w, hid_dim = image_features_hd.shape
# add the newline token to the HD image feature patches
newline_embeddings = self.sub_GN.expand(num_images, h, -1,
-1) # (n_img, h, 1, hid_dim)
image_features_hd_newline = torch.cat(
[image_features_hd, newline_embeddings],
dim=2).reshape(num_images, -1, hid_dim)
return image_features_hd_newline
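
As a quick arithmetic check (illustrative, not part of the file's diff), the vision tokens emitted per image by hd_feature_transform plus add_image_newline are the sub-image tokens h_crop*12*(w_crop*12+1), one glb_GN separator, and 12*(12+1) global tokens, which matches the (h*w+1)*144 + 1 + (h+1)*12 count the removed forward path computed inline:

    def phi3v_num_image_tokens(height: int, width: int) -> int:
        h_crop, w_crop = height // 336, width // 336
        sub_tokens = (h_crop * 12) * (w_crop * 12 + 1)  # sub-image grid + newline column
        glb_tokens = 12 * (12 + 1)                      # global grid + newline column
        return sub_tokens + 1 + glb_tokens              # +1 for the glb_GN separator

    # e.g. a 672x1344 image: a 2x4 crop grid -> 1333 tokens under either formula
    assert phi3v_num_image_tokens(672, 1344) == (2 * 4 + 1) * 144 + 1 + (2 + 1) * 12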


class Phi3VImagePixelInputs(TypedDict):
@@ -458,12 +454,12 @@ def __init__(self,

self.config = config
self.multimodal_config = multimodal_config
self.image_token_id = _IMAGE_TOKEN_ID

self.model = LlamaModel(config, cache_config, quant_config)

# TODO: Optionally initializes this for supporting embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
config, self.model.embed_tokens)
self.vision_embed_tokens = Phi3HDImageEmbedding(config)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
@@ -530,9 +526,12 @@ def forward(self,
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
inputs_embeds = self.vision_embed_tokens(
input_ids, image_input["data"], image_input["image_sizes"])

vision_embeddings = self.vision_embed_tokens(
image_input["data"], image_input["image_sizes"])
inputs_embeds = self.model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
self.image_token_id)
input_ids = None
else:
inputs_embeds = None
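
The merge step now delegates to merge_vision_embeddings from .utils instead of writing into the self.wte output in place. As a rough sketch of the contract assumed here (an illustration, not the actual helper in vllm/model_executor/models/utils.py), it scatters the projected vision embeddings into the text embeddings at the image placeholder positions:

    import torch

    def merge_vision_embeddings_sketch(input_ids: torch.Tensor,
                                       inputs_embeds: torch.Tensor,
                                       vision_embeddings: torch.Tensor,
                                       image_token_id: int) -> torch.Tensor:
        # Positions holding the image placeholder token.
        mask = input_ids == image_token_id
        flat = vision_embeddings.reshape(-1, inputs_embeds.shape[-1])
        assert flat.shape[0] == int(mask.sum()), \
            "number of image placeholder tokens must match vision embeddings"
        # Overwrite placeholder embeddings with the projected image features.
        inputs_embeds[mask] = flat.to(dtype=inputs_embeds.dtype)
        return inputs_embeds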