7 changes: 7 additions & 0 deletions deepspeed/inference/v2/engine_factory.py
@@ -24,6 +24,7 @@
QwenPolicy,
Qwen2Policy,
Qwen2MoePolicy,
ExaonePolicy,
)
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata
@@ -129,6 +130,12 @@ def build_hf_engine(path: str,
policy = Qwen2Policy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "qwen2_moe":
policy = Qwen2MoePolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "exaone4":
# Ensure we're using the correct version of transformers for EXAONE 4.0
import transformers
assert version.parse(transformers.__version__) >= version.parse("4.54.0"), \
f"EXAONE 4.0 requires transformers >= 4.54.0, you have version {transformers.__version__}"
policy = ExaonePolicy(model_config, checkpoint_engine=checkpoint_engine)
else:
raise ValueError(f"Unsupported model type {model_config.model_type}")

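For orientation, a minimal usage sketch of the new branch (not part of this diff). It assumes the existing v2 entry points, build_hf_engine as shown in the hunk above with a RaggedInferenceEngineConfig as its engine config (only the path parameter is visible here), and a placeholder checkpoint path.

# Hypothetical sketch: exercising the "exaone4" branch through the v2 factory.
# RaggedInferenceEngineConfig and the engine_config argument are assumed from the
# surrounding codebase; the checkpoint path is a placeholder.
from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig
from deepspeed.inference.v2.engine_factory import build_hf_engine

engine_config = RaggedInferenceEngineConfig()  # default ragged-batching settings
# Requires transformers >= 4.54.0, enforced by the assert added above.
engine = build_hf_engine("/path/to/exaone-4.0-checkpoint", engine_config)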
1 change: 1 addition & 0 deletions deepspeed/inference/v2/model_implementations/__init__.py
@@ -19,3 +19,4 @@
from .qwen import *
from .qwen_v2 import *
from .qwen_v2_moe import *
from .exaone import *
8 changes: 8 additions & 0 deletions deepspeed/inference/v2/model_implementations/exaone/__init__.py
@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .container import ExaoneTransformerContainer, ExaoneNonTransformerContainer
from .model import ExaoneInferenceModel
from .policy import ExaonePolicy
92 changes: 92 additions & 0 deletions deepspeed/inference/v2/model_implementations/exaone/container.py
@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Create a container object to save model-specific tensors for EXAONE 4.0

from ..common_parameters import *
from ..layer_container_base import LayerContainer
"""
HF EXAONE 4.0 model structure:

Exaone4ForCausalLM(
(model): Exaone4Model(
(embed_tokens): Embedding(102400, 5120)
(layers): ModuleList(
(0-63): 64 x Exaone4DecoderLayer(
(self_attn): Exaone4Attention(
(q_proj): Linear(in_features=5120, out_features=5120, bias=False)
(k_proj): Linear(in_features=5120, out_features=1024, bias=False)
(v_proj): Linear(in_features=5120, out_features=1024, bias=False)
(o_proj): Linear(in_features=5120, out_features=5120, bias=False)
(rotary_emb): Exaone4RotaryEmbedding()
)
(mlp): Exaone4MLP(
(gate_proj): Linear(in_features=5120, out_features=27392, bias=False)
(up_proj): Linear(in_features=5120, out_features=27392, bias=False)
(down_proj): Linear(in_features=27392, out_features=5120, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): Exaone4RMSNorm()
(post_attention_layernorm): Exaone4RMSNorm()
)
)
(norm): Exaone4RMSNorm()
)
(lm_head): Linear(in_features=5120, out_features=102400, bias=False)
)

Key EXAONE 4.0 features:
- Hybrid attention: sliding_attention (local) vs full_attention (global) layers
- Grouped Query Attention: 40 query heads, 8 key-value heads
- QK-Reorder-Norm: RMSNorm applied after Q/K projections
- SiLU activation in MLP
"""


class ExaoneTransformerContainer(LayerContainer):
"""
Transformer layer container for the EXAONE 4.0 model.
Handles both sliding_attention and full_attention layer types.
"""
qkv_w: UnfusedQKVParameter
attn_out_w: AttentionOutputParameter
mlp_1_w: GatedMLPParameter
mlp_2_w: MLP2Parameter
attn_norm_gamma: NormParameter
mlp_norm_gamma: NormParameter

PARAM_MAPPING = {
# Attention parameters - Q, K, V projections
"self_attn.q_proj.weight": "qkv_w.q_params",
"self_attn.k_proj.weight": "qkv_w.k_params",
"self_attn.v_proj.weight": "qkv_w.v_params",
"self_attn.o_proj.weight": "attn_out_w.params",

# MLP parameters - gate, up, down projections
"mlp.gate_proj.weight": "mlp_1_w.gate_params",
"mlp.up_proj.weight": "mlp_1_w.up_params",
"mlp.down_proj.weight": "mlp_2_w.params",

# Normalization parameters
"input_layernorm.weight": "attn_norm_gamma.params",
"post_attention_layernorm.weight": "mlp_norm_gamma.params",
}


class ExaoneNonTransformerContainer(LayerContainer):
"""
Non-Transformer layer container for the EXAONE 4.0 model.
Contains embedding, final normalization, and output projection parameters.
"""
word_emb: EmbeddingParameter
word_unembed: UnembedParameter
final_norm: NormParameter

PARAM_MAPPING = {
# Embedding and output parameters
"model.embed_tokens.weight": "word_emb.params",
"model.norm.weight": "final_norm.params",
"lm_head.weight": "word_unembed.params",
}
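The container docstring above cites Grouped Query Attention with 40 query heads and 8 key/value heads; the following standalone arithmetic (an illustration only, using the projection shapes printed in the module structure) shows how those shapes map onto heads.

# Illustration only: relating the q/k/v projection shapes in the docstring above
# to the GQA layout (40 query heads, 8 key/value heads, one shared head_dim).
hidden_size = 5120
q_out, kv_out = 5120, 1024           # q_proj out_features vs. k_proj/v_proj out_features
n_heads, n_kv_heads = 40, 8
head_dim = q_out // n_heads               # 5120 / 40 = 128
assert n_heads * head_dim == q_out        # 40 * 128 = 5120
assert n_kv_heads * head_dim == kv_out    # 8 * 128 = 1024
print(f"head_dim={head_dim}, query heads per KV head={n_heads // n_kv_heads}")  # 128, 5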
271 changes: 271 additions & 0 deletions deepspeed/inference/v2/model_implementations/exaone/model.py
@@ -0,0 +1,271 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Iterable, Optional, Tuple

import torch

import deepspeed.comm as dist

from ...allocator import empty_from
from ...inference_utils import ActivationType, DtypeEnum
from .. import *
from ...modules.configs import *
from ...modules.interfaces import *
from ...ragged import RaggedBatchWrapper

from .container import ExaoneNonTransformerContainer, ExaoneTransformerContainer


class ExaoneInferenceModel(DSTransformerModelBase):
"""
Inference model implementation for ragged batching for EXAONE 4.0 models.

Key features:
- Hybrid attention: sliding_attention (local) vs full_attention (global) layers
- QK-Reorder-Norm: RMSNorm applied after Q/K projections
- Conditional RoPE: Skip RoPE for full_attention layers
- Grouped Query Attention: 40 query heads, 8 key-value heads
"""

_non_transformer: Optional[ExaoneNonTransformerContainer]
"""
Embed + unembed container. Specializing the type annotation.
"""

_transformer: Optional[Iterable[ExaoneTransformerContainer]]
"""
Per-layer transformer container. Specializing the type annotation.
"""

# EXAONE 4.0 specific attributes
_layer_types: Optional[list] = None
"""
Layer types for hybrid attention: 'sliding_attention' or 'full_attention'
"""

def __init__(self, config, engine_config, base_mp_group):
super().__init__(config, engine_config, base_mp_group)

# Store layer types for hybrid attention handling
if hasattr(self._config, 'layer_types'):
self._layer_types = self._config.layer_types
else:
# Fallback: infer from sliding_window_pattern (LLLG = 3 local, 1 global)
pattern = getattr(self._config, 'sliding_window_pattern', 'LLLG')
layer_types = []
for i in range(self.num_layers):
if pattern[i % len(pattern)] == 'G':
layer_types.append('full_attention')
else:
layer_types.append('sliding_attention')
self._layer_types = layer_types

"""
Properties inherited from `DSInferenceModelBase`
"""

@property
def max_sequence_length(self) -> int:
return self._config.max_position_embeddings

"""
Properties inherited from `DSTransformerModelBase`
"""

@property
def num_layers(self) -> int:
return self._config.num_hidden_layers

@property
def model_dim(self) -> int:
return self._config.hidden_size

@property
def vocab_size(self) -> int:
return self._config.vocab_size

@property
def head_size(self) -> int:
return getattr(self._config, 'head_dim', self.model_dim // self.n_heads)

@property
def n_heads(self) -> int:
return self._config.num_attention_heads

@property
def intermediate_dim(self) -> int:
return self._config.intermediate_size

@property
def n_heads_kv(self) -> int:
return self._config.num_key_value_heads

@property
def activation_dtype(self) -> DtypeEnum:
if self._config.torch_dtype == torch.float16:
return DtypeEnum.fp16
elif self._config.torch_dtype == torch.bfloat16:
return DtypeEnum.bf16
else:
raise NotImplementedError("Only fp16 and bf16 are supported")

@property
def mlp_activation_fn(self) -> ActivationType:
activation = self._config.hidden_act.lower()
# EXAONE 4.0 uses gated SiLU activation like LLaMA
if activation == "silu":
return ActivationType.SiGLU
elif activation == "gelu":
return ActivationType.GEGLU
elif activation == "relu":
return ActivationType.ReGLU
else:
raise NotImplementedError(f"Activation {activation} not supported")

@property
def norm_type(self) -> NormTypeEnum:
return NormTypeEnum.RMSNorm

@property
def positional_embedding_type(self) -> PositionalEmbeddingType:
return PositionalEmbeddingType.rotate_half

@property
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
return RotateHalfConfig(theta_base=self._config.rope_theta)

"""
Helper methods for EXAONE 4.0 specific features
"""

def is_global_attention_layer(self, layer_idx: int) -> bool:
"""Check if layer uses global (full) attention vs local (sliding) attention"""
if self._layer_types and layer_idx < len(self._layer_types):
return self._layer_types[layer_idx] == 'full_attention'
return False

def should_apply_rope(self, layer_idx: int) -> bool:
"""EXAONE 4.0 skips RoPE for global attention layers"""
return not self.is_global_attention_layer(layer_idx)

"""
Forward implementations
"""

def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor:
"""
Performs the embedding lookup prior to running the transformer of the model.

Arguments:
ragged_batch (RaggedBatchWrapper): The batch to embed.

Returns:
torch.Tensor: The embedded batch.
"""
embed = self.embed(ragged_batch, self._non_transformer.word_emb)

if embed.shape[-1] != self.model_dim:
raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}")

return embed

def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor,
ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Executes one transformer layer with EXAONE 4.0 specific features:
- Hybrid attention (sliding vs full)
- QK-Reorder-Norm (RMSNorm after Q/K projections)
- Conditional RoPE (skip for global layers)

Arguments:
layer_idx (int): The index of the layer to execute.
residual (torch.Tensor): The residual tensor from the previous layer.
hidden_states (torch.Tensor): The hidden states from the previous layer.
ragged_batch_info (RaggedBatchWrapper): The batch metadata.
"""
cur_params = self._transformer[layer_idx]
kv_cache = self.state_manager.get_cache(layer_idx)

# QKV projection
hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None)

# EXAONE 4.0 attention with hybrid pattern and conditional RoPE
# NOTE: The attention module should handle QK-Reorder-Norm internally
# and respect the RoPE configuration based on layer type
if self.is_global_attention_layer(layer_idx):
# Global attention: full attention, no RoPE
hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info, apply_rotary_pos_emb=False)
else:
# Local attention: sliding window, with RoPE
hidden_states = self.attn(hidden_states,
kv_cache,
ragged_batch_info,
apply_rotary_pos_emb=True,
sliding_window=getattr(self._config, 'sliding_window', 4096))

# Attention output projection
hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None)

if self.tp_size > 1:
dist.all_reduce(hidden_states, group=self._base_mp_group)

# Post-attention normalization
residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None)

# MLP forward pass (gated SiLU)
hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None)
hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None)

if self.tp_size > 1:
dist.all_reduce(hidden_states, group=self._base_mp_group)

# Prepare for next layer normalization
if layer_idx != self.num_layers - 1:
next_params = self._transformer[layer_idx + 1]
residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None)
else:
# On last layer, just perform the residual add
residual.add_(hidden_states)

return residual, hidden_states

def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor:
"""
Performs unembedding of the hidden states to logits. This will only sample the final
token of each sequence.
"""
logits = self.unembed(hidden_states,
self._non_transformer.word_unembed,
ragged_batch_info,
gamma=self._non_transformer.final_norm)

if self.tp_size > 1:
comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size))

dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group)

full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size))

return full_logits
else:
return logits

def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
"""
Forward pass for EXAONE 4.0 model with hybrid attention support.
"""
residual = self._forward_embed(wrapped_batch)

# Initial normalization
residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None)

# Forward through all transformer layers
for layer_idx in range(self.num_layers):
residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
wrapped_batch)

return self._forward_unembed(residual, wrapped_batch)
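To make the hybrid-attention handling easier to follow, here is a self-contained sketch of the 'LLLG' fallback from ExaoneInferenceModel.__init__ and the RoPE decision from should_apply_rope, run over a hypothetical 8-layer configuration (illustrative values, not part of the diff).

# Illustrative sketch: the same 'LLLG' fallback as in ExaoneInferenceModel.__init__,
# applied to a hypothetical 8-layer model.
pattern = "LLLG"    # L = sliding (local) attention, G = full (global) attention
num_layers = 8
layer_types = [
    "full_attention" if pattern[i % len(pattern)] == "G" else "sliding_attention"
    for i in range(num_layers)
]
# -> layers 0-2 and 4-6 are sliding_attention, layers 3 and 7 are full_attention

for idx, kind in enumerate(layer_types):
    # Mirrors should_apply_rope: global (full_attention) layers skip RoPE.
    apply_rope = kind != "full_attention"
    print(f"layer {idx}: {kind}, RoPE applied: {apply_rope}")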