from typing import Union

import torch
from torch import nn

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \
    Qwen2MoeHfWeightMapper
from tensorrt_llm._torch.models.modeling_nemotron_h import split
from tensorrt_llm._torch.models.modeling_utils import register_mapper
from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM


@register_mapper("HF", "Qwen3NextForCausalLM")
class Qwen3NextHfWeightMapper(Qwen2MoeHfWeightMapper):
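    """Weight mapper for HF Qwen3NextForCausalLM checkpoints.

    Extends the Qwen2-MoE mapper with Qwen3Next-specific handling of the
    linear-attention weights (A_log, dt_bias, in_proj, conv1d), which need to
    be resharded for tensor parallelism.
    """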

    def init_model_and_config(self, model: Union[nn.Module,
                                                  DecoderModelForCausalLM],
                               config: ModelConfig):
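        """Run the base initialization and cache the number of KV heads.

        Falls back to num_attention_heads when the config does not define
        num_key_value_heads.
        """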
        super().init_model_and_config(model, config)
        num_kv_heads = getattr(model.config, 'num_key_value_heads', None)
        self._num_kv_heads = (num_kv_heads if num_kv_heads is not None else
                              model.config.num_attention_heads)

    def should_skip_module(self, module_name: str) -> bool:
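        """Skip draft-model weights in addition to the modules skipped by the parent class."""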
        if module_name.startswith("draft_model"):
            return True
        return super().should_skip_module(module_name)

    def _duplicate_kv_weights(self, module: nn.Module, new_name: str,
                              weights: dict):
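        """Duplicate k_proj/v_proj tensors (and any quantization scale tensors)
        as needed for tensor parallelism; the KV-head count may be a single int
        or a per-tensor list."""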
        tensors_to_duplicate = ["weight", "bias"]
        if module.quant_config.quant_mode.has_nvfp4():
            tensors_to_duplicate.append("weight_scale")
        if module.quant_config.quant_mode.has_fp8_block_scales():
            tensors_to_duplicate.append("weight_scale_inv")

        if new_name in ['k_proj', 'v_proj']:
            num_kv_heads_list = ([self._num_kv_heads] * len(weights)
                                 if isinstance(self._num_kv_heads, int) else
                                 self._num_kv_heads)
            processed_weights = {
                k: self._duplicate_kv(weight=v[:],
                                      num_kv_heads=num_kv_heads_list[i],
                                      tensor_parallel_size=self._tp_size)
                if k in tensors_to_duplicate else v
                for i, (k, v) in enumerate(weights.items())
            }
            return processed_weights

        return weights

    def preprocess_weights(self, weights: dict) -> dict:
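        """Reshard the Qwen3Next linear-attention weights for tensor parallelism.

        A_log and dt_bias are split per rank and cast to fp32, in_proj is kept
        unsplit, and the fused conv1d weights are split into q/k/v segments and
        re-interleaved so that each rank receives a contiguous shard.
        """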
        config = self.config.pretrained_config
        tp_size = self.config.mapping.tp_size
        tp_rank = self.config.mapping.tp_rank

        # Sizes of the fused q/k/v segments in the linear-attention conv1d weights.
        linear_key_dim = config.linear_key_head_dim * config.linear_num_key_heads  # 16 * 128
        linear_value_dim = config.linear_value_head_dim * config.linear_num_value_heads  # 32 * 128

        new_weights = {}
        for name, w in weights.items():
            key = name

            if "A_log" in key or "dt_bias" in key:
                # Per-head parameters: take this rank's slice and keep it in fp32.
                w = split(w, tp_size, tp_rank)
                new_weights[key] = w.to(torch.float32)
            elif "in_proj" in key:
                # The reference implementation does not split in_proj across TP
                # ranks, so it is kept unsplit here as well; the exact reason is
                # still to be confirmed.
                new_weights[key] = w
            elif "conv1d" in key:
                # Drop dim 1 because the conv1d weights are stored in a Linear module.
                if "weight" in key:
                    w = w.squeeze(1)

                conv_q, conv_k, conv_v = torch.split(
                    w, [linear_key_dim, linear_key_dim, linear_value_dim],
                    dim=0)

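                # Shard q/k/v separately for every rank, then re-concatenate so
                # that each rank's contiguous slice holds its q, k and v parts.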
                w = []
                for rank in range(tp_size):
                    conv_q_rank = split(conv_q, tp_size, rank)
                    conv_k_rank = split(conv_k, tp_size, rank)
                    conv_v_rank = split(conv_v, tp_size, rank)
                    y = torch.concat([conv_q_rank, conv_k_rank, conv_v_rank])
                    w.append(y)
                w = torch.concat(w).contiguous()
                new_weights[key] = w
            else:
                new_weights[key] = w

        return new_weights