16 changes: 16 additions & 0 deletions python/mlc_llm/conversation_template/phi.py
@@ -35,3 +35,19 @@
stop_token_ids=[32000, 32001, 32007],
)
)

# Phi-3-vision
ConvTemplateRegistry.register_conv_template(
Conversation(
name="phi-3-vision",
system_template=f"{MessagePlaceholders.SYSTEM.value}",
system_message="\n",
roles={"user": "<|user|>", "assistant": "<|assistant|>"},
seps=["<|end|>\n"],
role_content_sep="\n",
role_empty_sep="\n",
system_prefix_token_ids=[1],
stop_str=["<|endoftext|>"],
stop_token_ids=[2, 32000, 32001, 32007],
)
)
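
A minimal rendering sketch (not MLC-LLM's actual prompt builder) of how the fields above compose for one user turn, assuming the usual role + role_content_sep + content + seps[0] composition:

# Hypothetical: renders the phi-3-vision template registered above.
# The composition rule itself is an assumption, not read from this PR.
def render_phi3_vision(user_msg: str) -> str:
    prompt = "\n"                            # system_message
    prompt += "<|user|>" + "\n" + user_msg   # role + role_content_sep + content
    prompt += "<|end|>\n"                    # seps[0]
    prompt += "<|assistant|>" + "\n"         # empty assistant turn; model continues
    return prompt

print(repr(render_phi3_vision("<|image_1|>\nWhat is shown in this image?")))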
1 change: 1 addition & 0 deletions python/mlc_llm/interface/gen_config.py
@@ -290,6 +290,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-branches
"custom", # for web-llm only
"phi-2",
"phi-3",
"phi-3-vision",
"stablelm-2",
"gemma_instruction",
"orion",
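The "phi-3-vision" string added above must name a template registered with ConvTemplateRegistry; a lookup sketch (both the import path and get_conv_template are assumptions inferred from the surrounding code, not verified API docs):

# Illustrative only: fetch the template that gen_config will accept.
from mlc_llm.conversation_template import ConvTemplateRegistry  # assumed path

conv = ConvTemplateRegistry.get_conv_template("phi-3-vision")   # assumed helper
print(conv.roles, conv.stop_token_ids)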
245 changes: 9 additions & 236 deletions python/mlc_llm/model/llava/llava_model.py
@@ -5,25 +5,17 @@

import dataclasses
import logging
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional

from tvm import relax, tir
from tvm import tir
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Module, Tensor
from tvm.relax.frontend.nn.modules import Conv2D
from tvm.relax.frontend.nn.op import (
broadcast_to,
concat,
matmul,
permute_dims,
reshape,
softmax,
wrap_nested,
)
from tvm.relax.op import arange, strided_slice
from tvm.relax.frontend.nn.op import reshape, wrap_nested
from tvm.relax.op import strided_slice

from mlc_llm import op as op_ext
from mlc_llm.model.model_preset import MODEL_PRESETS
from mlc_llm.model.vision import CLIPVisionConfig, CLIPVisionModel
from mlc_llm.nn import PagedKVCache, RopeMode

from ...support.config import ConfigBase
@@ -33,25 +33,6 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class LlavaVisionConfig(ConfigBase): # pylint: disable=too-many-instance-attributes
"""
Config for the vision encoder
"""

hidden_size: int
image_size: int
intermediate_size: int
num_attention_heads: int
num_hidden_layers: int
patch_size: int
projection_dim: int
vocab_size: int
num_channels: int = 3
layer_norm_eps: float = 1e-06
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)


CONFIG_MAP = {"LlamaForCausalLM": LlamaConfig, "MistralForCausalLM": MistralConfig}
ARCHITECTURE_MAP = {"LlamaForCausalLM": LlamaForCasualLM, "MistralForCausalLM": MistralForCasualLM}

@@ -64,7 +37,7 @@ class LlavaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes

image_token_index: int
text_config: LlamaConfig
vision_config: LlavaVisionConfig
vision_config: CLIPVisionConfig
vocab_size: int
context_window_size: int = -1
sliding_window_size: int = -1
@@ -76,15 +49,15 @@ class LlavaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes

def __post_init__(self) -> None:
vision_config_dict: Dict[str, Any]
if isinstance(self.vision_config, LlavaVisionConfig):
if isinstance(self.vision_config, CLIPVisionConfig):
vision_config_dict = dataclasses.asdict(self.vision_config)
else:
vision_config_dict = dict(self.vision_config)

for k, v in vision_config_dict.pop("kwargs", {}).items():
vision_config_dict[k] = v

self.vision_config = LlavaVisionConfig.from_dict(vision_config_dict)
self.vision_config = CLIPVisionConfig.from_dict(vision_config_dict)

text_config_dict: Dict[str, Any]
if isinstance(self.text_config, ConfigBase):
@@ -139,207 +112,7 @@ def get_hf_config(self, text_config_dict: Dict[str, Any]) -> Dict[str, Any]:
return hf_config
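
Backing up to __post_init__ above: vision_config may arrive either as a dataclass or as a raw dict, and any "kwargs" overflow is folded back into the dict before from_dict. A self-contained sketch of that pattern, using a hypothetical Cfg dataclass rather than MLC-LLM's ConfigBase:

# Hypothetical illustration of the normalization done in __post_init__.
import dataclasses
from typing import Any, Dict

@dataclasses.dataclass
class Cfg:
    hidden_size: int
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def normalize(vision_config) -> Cfg:
    if isinstance(vision_config, Cfg):
        d = dataclasses.asdict(vision_config)
    else:
        d = dict(vision_config)
    for k, v in d.pop("kwargs", {}).items():
        d[k] = v                  # overflow keys override the flat ones
    return Cfg(**d)

print(normalize({"hidden_size": 1024}))
print(normalize(Cfg(hidden_size=1024, kwargs={"hidden_size": 768})))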


# pylint: disable=missing-docstring


class CLIPVisionEmbeddings(Module): # pylint: disable=too-many-instance-attributes
def __init__(self, config: LlavaVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter((self.embed_dim,))
self.patch_embedding = Conv2D(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)

self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(num=self.num_positions, dim=self.embed_dim)

def forward(self, pixel_values: Tensor) -> Tensor:
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
patch_embeds = reshape(patch_embeds, shape=(batch_size, self.embed_dim, -1))
patch_embeds = permute_dims(
patch_embeds, axes=(0, 2, 1)
) # shape = [batch,grid*grid,embed_dim]
class_embeds = broadcast_to(
self.class_embedding, shape=(batch_size, 1, self.embed_dim)
) # shape of (batch,1,embed_dim)
embeddings = concat([class_embeds, patch_embeds], dim=1)

posi_ids = reshape(
wrap_nested(arange(0, self.num_positions, dtype="int32"), name="arange"), shape=(1, -1)
)
batch_position_embedding = broadcast_to(
self.position_embedding(posi_ids),
shape=(batch_size, self.num_positions, self.embed_dim),
)
embeddings = embeddings + batch_position_embedding
return embeddings
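
Worked numbers for the embedding shapes above, using CLIP ViT-L/14 at 336 px (the encoder LLaVA variants typically pair with; these values are an assumption, not read from this diff):

# image -> patch grid -> token count, per CLIPVisionEmbeddings.forward
image_size, patch_size = 336, 14
grid = image_size // patch_size    # 24 patches per side
num_patches = grid * grid          # 576 patch embeddings
num_positions = num_patches + 1    # +1 class embedding -> 577 positions
print(num_patches, num_positions)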


def sigmoid(x: Tensor, name: str = "sigmoid") -> Tensor:
"""Sigmoid of a Tensor

Parameters
----------
x : Tensor
Input tensor to expand.
name : str
Name hint for this operator.

Returns
-------
result : Tensor
Sigmoid result.
"""
return wrap_nested(relax.op.sigmoid(x._expr), name) # pylint: disable=protected-access


class LlavaQuickGELU(Module):
def forward(self, input_tensor: Tensor) -> Tensor:
return input_tensor * sigmoid(input_tensor * 1.702)
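
QuickGELU is the cheap sigmoid approximation of exact GELU, x * Phi(x); a quick numeric check (assumes numpy is available):

# QuickGELU x*sigmoid(1.702*x) vs exact GELU; the max gap is only ~0.02.
import numpy as np
from math import erf, sqrt

x = np.linspace(-4.0, 4.0, 9)
quick = x / (1.0 + np.exp(-1.702 * x))
exact = x * np.array([(1.0 + erf(v / sqrt(2.0))) / 2.0 for v in x])
print(float(np.max(np.abs(quick - exact))))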


class CLIPMLP(Module):
def __init__(self, config: LlavaVisionConfig):
super().__init__()
self.activation_fn = LlavaQuickGELU()
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

def forward(self, hidden_states: Tensor) -> Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states


class CLIPAttention(Module): # pylint: disable=too-many-instance-attributes
def __init__(self, config: LlavaVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if (self.head_dim * self.num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

def _shape(self, tensor: Tensor, seq_len: int, bsz: int):
reshape_tensor = reshape(tensor, shape=(bsz, seq_len, self.num_heads, self.head_dim))
permute_tensor = permute_dims(reshape_tensor, axes=(0, 2, 1, 3))
return permute_tensor

def forward(
self,
hidden_states: Tensor,
) -> Tensor:
bsz, tgt_len, embed_dim = hidden_states.shape
query_states = self._shape(self.q_proj(hidden_states) * self.scale, tgt_len, bsz)
key_states = self._shape(self.k_proj(hidden_states), tgt_len, bsz)
value_states = self._shape(self.v_proj(hidden_states), tgt_len, bsz)

proj_shape = (
bsz * self.num_heads,
-1,
self.head_dim,
) # shape of (batch*num_heads, seq_len,head_dim)

query_states = reshape(query_states, shape=proj_shape)
key_states = reshape(key_states, shape=proj_shape)
value_states = reshape(value_states, shape=proj_shape)

trans_key_states = permute_dims(key_states, axes=(0, 2, 1))

attn_weights = matmul(query_states, trans_key_states)
attn_weights = softmax(attn_weights, axis=-1)
attn_output = matmul(attn_weights, value_states)
attn_output = reshape(attn_output, shape=(bsz, self.num_heads, tgt_len, self.head_dim))
attn_output = permute_dims(attn_output, axes=(0, 2, 1, 3))
attn_output = reshape(attn_output, shape=(bsz, tgt_len, embed_dim))
attn_output = self.out_proj(attn_output)

return attn_output
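
The forward above is plain unmasked attention with the query pre-scaled; a numpy re-trace of the same reshape/permute flow (illustrative shapes, not MLC-LLM code):

# Heads folded into the batch axis, softmax over keys, heads merged back.
import numpy as np

bsz, seq, heads, hd = 2, 5, 4, 8
q = np.random.randn(bsz * heads, seq, hd) * hd**-0.5   # query scaled up front
k = np.random.randn(bsz * heads, seq, hd)
v = np.random.randn(bsz * heads, seq, hd)

w = q @ k.transpose(0, 2, 1)                 # (batch*heads, seq, seq)
w = np.exp(w - w.max(-1, keepdims=True))
w /= w.sum(-1, keepdims=True)                # softmax over key axis
out = (w @ v).reshape(bsz, heads, seq, hd)
out = out.transpose(0, 2, 1, 3).reshape(bsz, seq, heads * hd)
print(out.shape)                             # (2, 5, 32)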


class CLIPEncoderLayer(Module):
def __init__(self, config: LlavaVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = CLIPAttention(config)
self.layer_norm1 = nn.LayerNorm(normalized_shape=self.embed_dim, eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config)
self.layer_norm2 = nn.LayerNorm(normalized_shape=self.embed_dim, eps=config.layer_norm_eps)

def forward(self, hidden_states: Tensor) -> Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
return outputs


class CLIPEncoder(Module):
def __init__(self, config: LlavaVisionConfig):
super().__init__()
self.layers = nn.ModuleList(
[CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)

def forward(self, inputs_embeds: Tensor) -> Tensor:
hidden_states = inputs_embeds
encoder_states: Tuple[Any, ...] = ()
for _, encoder_layer in enumerate(self.layers):
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(hidden_states)
hidden_states = layer_outputs[0]
encoder_states = encoder_states + (hidden_states,)
return encoder_states


class CLIPVisionTransformer(Module):
def __init__(self, config: LlavaVisionConfig):
super().__init__()
embed_dim = config.hidden_size
self.embeddings = CLIPVisionEmbeddings(config)
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = CLIPEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

def forward(self, pixel_values: Tensor) -> Tensor:
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
return encoder_outputs


class CLIPVisionModel(Module):
def __init__(self, config: LlavaVisionConfig):
super().__init__()
self.vision_model = CLIPVisionTransformer(config)

def forward(self, pixel_values: Tensor) -> Tensor:
return self.vision_model(pixel_values)[-2]
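
Why index [-2]: CLIPEncoder.forward collects the hidden state before every layer plus once after the last, so it returns num_hidden_layers + 1 states, and [-2] is the output of the penultimate layer, the feature layer LLaVA conventionally feeds to its projector. In numbers (CLIP ViT-L's 24 layers, for illustration):

# 24 encoder layers -> 25 collected states; index -2 is layer 23's output.
num_hidden_layers = 24
num_states = num_hidden_layers + 1   # embeddings + one state per layer
print(num_states - 2)                # 23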
# pylint: disable=invalid-name,missing-docstring


class LlavaMultiModalProjector(nn.Module):
15 changes: 15 additions & 0 deletions python/mlc_llm/model/model.py
@@ -26,6 +26,7 @@
from .orion import orion_loader, orion_model, orion_quantization
from .phi import phi_loader, phi_model, phi_quantization
from .phi3 import phi3_loader, phi3_model, phi3_quantization
from .phi3v import phi3v_loader, phi3v_model, phi3v_quantization
from .qwen import qwen_loader, qwen_model, qwen_quantization
from .qwen2 import qwen2_loader, qwen2_model, qwen2_quantization
from .qwen2_moe import qwen2_moe_loader, qwen2_moe_model, qwen2_moe_quantization
@@ -220,6 +221,20 @@ class Model:
"ft-quant": phi3_quantization.ft_quant,
},
),
"phi3_v": Model(
name="phi3_v",
model=phi3v_model.Phi3VForCausalLM,
config=phi3v_model.Phi3VConfig,
source={
"huggingface-torch": phi3v_loader.huggingface,
"huggingface-safetensor": phi3v_loader.huggingface,
},
quantize={
"no-quant": phi3v_quantization.no_quant,
"group-quant": phi3v_quantization.group_quant,
"ft-quant": phi3v_quantization.ft_quant,
},
),
"qwen": Model(
name="qwen",
model=qwen_model.QWenLMHeadModel,
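Once registered, the "phi3_v" entry is looked up by name from the MODELS dict this file defines; a hedged consumption sketch (field names follow the Model dataclass above, treat the rest as illustration):

# Illustrative lookup of the new registry entry.
from mlc_llm.model.model import MODELS

entry = MODELS["phi3_v"]
print(entry.name, sorted(entry.quantize))  # phi3_v ['ft-quant', 'group-quant', 'no-quant']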
4 changes: 4 additions & 0 deletions python/mlc_llm/model/phi3/__init__.py
@@ -0,0 +1,4 @@
"""Common `nn.Modules` used to define LLMs in this project."""


from .phi3_model import Phi3Model