3 changes: 1 addition & 2 deletions tensorrt_llm/_torch/attention_backend/interface.py

@@ -392,8 +392,7 @@ def create_rope_const_params(self, interleave: bool = True):
             )
 
         if self.scale_type == RotaryScalingType.yarn:
-            rope_inv_freq = None
-            _, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
+            rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
                 self.max_positions,
                 self.dim,
                 self.theta,
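For context on the yarn change above: `create_sinusoidal_positions_yarn` returns both an inverse-frequency table and a cos/sin cache, and the call site now keeps both values instead of forcing `rope_inv_freq` to `None` and discarding the first return value. A minimal standalone sketch of that pattern (toy RoPE math, not the TensorRT-LLM utility or its full YaRN frequency correction) could look like this:

import numpy as np


def toy_yarn_rope_const_params(max_positions: int, dim: int, theta: float = 10000.0):
    """Return (inv_freq, cos_sin) so callers can unpack both values, as the
    updated yarn branch now does, instead of discarding the first one."""
    # Plain RoPE inverse frequencies; real YaRN additionally rescales these
    # per frequency band and applies an attention-scale (mscale) correction.
    inv_freq = 1.0 / (theta ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
    # Rotation angle for every (position, frequency) pair.
    angles = np.outer(np.arange(max_positions, dtype=np.float32), inv_freq)
    # Pack cos and sin into one contiguous cache, one row per position.
    cos_sin = np.concatenate([np.cos(angles), np.sin(angles)], axis=-1)
    return inv_freq, cos_sin.reshape(1, -1)


rope_inv_freq, rope_cos_sin = toy_yarn_rope_const_params(max_positions=4096, dim=128)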
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/attention_backend/trtllm.py

@@ -110,7 +110,7 @@ def __init__(
         self.qk_rope_head_dim = None
         self.v_head_dim = None
 
-        self.rotary_inv_freq, self.rotary_cos_sin = rope_params.create_rope_const_params(
+        self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
        )
 
         self.num_heads = num_heads
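The trtllm.py change reads the RoPE parameters from the instance attribute rather than the constructor local. A hypothetical reduced sketch of the resulting pattern, assuming `self.rope_params` is assigned earlier in the real `__init__` (that assignment is not part of this hunk):

class ToyAttentionWrapper:
    """Hypothetical stand-in for the wrapper class; only the RoPE bits shown."""

    def __init__(self, rope_params):
        # Assumed assignment; in the real file it happens outside this hunk.
        self.rope_params = rope_params
        # Derive the cached constants from the stored attribute so later code
        # that reads self.rope_params sees the object the constants came from.
        self.rotary_inv_freq, self.rotary_cos_sin = (
            self.rope_params.create_rope_const_params())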
15 changes: 12 additions & 3 deletions tensorrt_llm/_torch/models/modeling_nemotron_nas.py

@@ -4,7 +4,7 @@
 from torch import nn
 from transformers import PretrainedConfig
 
-from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
 from tensorrt_llm.lora_manager import HfLoraLoader
 from tensorrt_llm.models.convert_utils import split_matrix_tp
 
@@ -48,19 +48,28 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig],
 
 
 class NemotronNASAttention(Attention):
+    NON_NEOX_TYPES = ("mistral_yarn", "rope_llama4")
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  layer_idx: int):
         config = model_config.pretrained_config
+        is_neox = getattr(model_config.pretrained_config,
+                          "position_embedding_type",
+                          None) not in self.NON_NEOX_TYPES
+        rope = RopeParams.from_config(config)
+        if rope.scale_type == RotaryScalingType.yarn:
+            rope.mscale_all_dim = 0.0
+
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads[layer_idx],
             max_position_embeddings=config.max_position_embeddings,
             bias=False,
             pos_embd_params=PositionalEmbeddingParams(
-                type=PositionEmbeddingType.rope_gpt_neox,
-                rope=RopeParams.from_config(config),
+                type=PositionEmbeddingType.rope_gpt_neox
+                if is_neox else PositionEmbeddingType.rope_gptj,
+                rope=rope,
             ),
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
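As a reading aid for the new `is_neox` logic: the attention layer keeps NeoX-style (rotate-half) RoPE unless the checkpoint's `position_embedding_type` names one of the two non-NeoX types, in which case it switches to GPT-J-style (interleaved) RoPE; separately, `mscale_all_dim` is zeroed whenever YaRN scaling is configured. A small self-contained illustration of the selection check, using a stand-in config object rather than a real `PretrainedConfig`:

from types import SimpleNamespace

NON_NEOX_TYPES = ("mistral_yarn", "rope_llama4")


def pick_rope_style(pretrained_config) -> str:
    """Mirror the is_neox check: default to NeoX-style RoPE unless the config
    explicitly names one of the non-NeoX position embedding types."""
    embedding_type = getattr(pretrained_config, "position_embedding_type", None)
    return "rope_gpt_neox" if embedding_type not in NON_NEOX_TYPES else "rope_gptj"


assert pick_rope_style(SimpleNamespace()) == "rope_gpt_neox"
assert pick_rope_style(SimpleNamespace(position_embedding_type="mistral_yarn")) == "rope_gptj"
assert pick_rope_style(SimpleNamespace(position_embedding_type="rope_llama4")) == "rope_gptj"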