@@ -716,6 +716,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
             { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
         },
     },
@@ -1744,6 +1746,7 @@ enum e_model {
     MODEL_4B,
     MODEL_7B,
     MODEL_8B,
+    MODEL_12B,
     MODEL_13B,
     MODEL_14B,
     MODEL_15B,
@@ -3607,6 +3610,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_3B: return "3B";
         case MODEL_7B: return "7B";
         case MODEL_8B: return "8B";
+        case MODEL_12B: return "12B";
         case MODEL_13B: return "13B";
         case MODEL_14B: return "14B";
         case MODEL_15B: return "15B";
@@ -3898,6 +3902,7 @@ static void llm_load_hparams(
                 switch (hparams.n_layer) {
                     case 24: model.type = e_model::MODEL_1B; break;
                     case 32: model.type = e_model::MODEL_3B; break;
+                    case 40: model.type = e_model::MODEL_12B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
@@ -5128,8 +5133,13 @@ static bool llm_load_tensors(
                     layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false);
                     layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false);
 
-                    layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                    layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
+                    // optional q and k layernorms, present in StableLM 2 12B
+                    layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}, false);
+                    layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}, false);
+
+                    // optional FFN norm, not present in StableLM 2 12B which uses parallel residual
+                    layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, false);
+                    layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, false);
 
                     layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
                     layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
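Note: the two tensors added above hold per-head scale vectors, {n_embd_head_k, n_head} for Q and {n_embd_head_k, n_head_kv} for K, applied head by head before RoPE in the graph-building hunks below. A minimal standalone sketch of that computation in plain C++ (not ggml), assuming LLM_NORM here denotes a standard LayerNorm over each head's n_embd_head_k values with a learned per-element scale and no bias; the helper name per_head_layernorm and the eps default are illustrative only:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative only: normalize each head's slice of a [n_tokens, n_head, head_dim]
// activation with a learned per-head, per-element scale (no bias), mirroring the
// {n_embd_head_k, n_head} weight shape loaded above. Assumes LayerNorm semantics.
static void per_head_layernorm(std::vector<float> & x, const std::vector<float> & w,
                               int n_tokens, int n_head, int head_dim, float eps = 1e-5f) {
    for (int t = 0; t < n_tokens; ++t) {
        for (int h = 0; h < n_head; ++h) {
            float *       v = x.data() + (size_t)(t * n_head + h) * head_dim;
            const float * g = w.data() + (size_t)h * head_dim;

            float mean = 0.0f;
            for (int i = 0; i < head_dim; ++i) mean += v[i];
            mean /= head_dim;

            float var = 0.0f;
            for (int i = 0; i < head_dim; ++i) { const float d = v[i] - mean; var += d * d; }
            var /= head_dim;

            const float inv_std = 1.0f / std::sqrt(var + eps);
            for (int i = 0; i < head_dim; ++i) v[i] = (v[i] - mean) * inv_std * g[i];
        }
    }
}

int main() {
    const int n_tokens = 2, n_head = 4, head_dim = 8;
    std::vector<float> q(n_tokens * n_head * head_dim, 1.5f);
    std::vector<float> w(n_head * head_dim, 1.0f);
    per_head_layernorm(q, w, n_tokens, n_head, head_dim);
    std::printf("q[0] after norm: %f\n", q[0]); // all-equal inputs normalize to 0
    return 0;
}
```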
@@ -8197,7 +8207,7 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
+
 
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -8206,6 +8216,8 @@ struct llm_build_context {
                     LLM_NORM, cb, il);
             cb(cur, "attn_norm", il);
 
+            struct ggml_tensor * inpSA = cur;
+
             // self-attention
             {
                 // compute Q and K and RoPE them
@@ -8230,15 +8242,36 @@ struct llm_build_context {
                     cb(Vcur, "Vcur", il);
                 }
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                cb(Qcur, "Qcur", il);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                cb(Kcur, "Kcur", il);
+
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
+                            model.layers[il].attn_q_norm,
+                            NULL,
+                            LLM_NORM, cb, il);
+                    cb(Qcur, "Qcur", il);
+                }
+                if (model.layers[il].attn_k_norm) {
+                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
+                            model.layers[il].attn_k_norm,
+                            NULL,
+                            LLM_NORM, cb, il);
+                    cb(Kcur, "Kcur", il);
+                }
+
+
                 Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    ctx0, Qcur, inp_pos,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    ctx0, Kcur, inp_pos,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
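Note: with this change, ggml_rope_custom receives the already reshaped (and optionally normalized) Q and K instead of reshaping inline. For reference, a minimal standalone sketch of what a basic rotary embedding does to one head vector, in plain C++: it ignores ggml's ext_factor/attn_factor/beta_fast/beta_slow extensions, pairs adjacent elements (NeoX-style RoPE pairs element i with i + n_rot/2 instead), and rope_basic is an illustrative name, not a ggml function:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative only: rotate the first n_rot elements of one head vector by
// position-dependent angles theta_i = pos * freq_scale * freq_base^(-2i/n_rot).
static void rope_basic(std::vector<float> & head, int pos, int n_rot,
                       float freq_base = 10000.0f, float freq_scale = 1.0f) {
    for (int i = 0; i < n_rot; i += 2) {
        const float theta = (float) pos * freq_scale * std::pow(freq_base, -(float) i / n_rot);
        const float c  = std::cos(theta);
        const float s  = std::sin(theta);
        const float x0 = head[i];
        const float x1 = head[i + 1];
        head[i]     = x0 * c - x1 * s;
        head[i + 1] = x0 * s + x1 * c;
    }
}

int main() {
    std::vector<float> head(8, 1.0f);
    rope_basic(head, /*pos=*/3, /*n_rot=*/8);
    for (float v : head) std::printf("%f ", v); // rotated query/key head for position 3
    std::printf("\n");
    return 0;
}
```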
@@ -8253,20 +8286,25 @@ struct llm_build_context {
                 // skip computing output for unused tokens
                 struct ggml_tensor * inp_out_ids = build_inp_out_ids();
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
             {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
+                if (model.layers[il].ffn_norm) {
+                    cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                            model.layers[il].ffn_norm,
+                            model.layers[il].ffn_norm_b,
+                            LLM_NORM, cb, il);
+                    cb(cur, "ffn_norm", il);
+                } else {
+                    // parallel residual
+                    cur = inpSA;
+                }
                 cur = llm_build_ffn(ctx0, cur,
                         model.layers[il].ffn_up, NULL,
                         model.layers[il].ffn_gate, NULL,
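Note: in this hunk, when ffn_norm is absent (StableLM 2 12B) the FFN reuses inpSA, the output of the attention layernorm, so attention and FFN branch off the same normalized input (the "parallel residual" named in the comment); when ffn_norm is present, the FFN re-normalizes ffn_inp as before. A minimal dataflow sketch of the two wirings, with a float standing in for a tensor and placeholder functions for the real layers; the final add at the end is the typical combination and falls outside the hunk shown:

```cpp
#include <cstdio>

// Placeholder "layers": in the real graph these are tensor ops built with ggml.
static float attn_norm(float x) { return x * 0.5f; }
static float ffn_norm (float x) { return x * 0.5f; }
static float attn     (float x) { return x + 1.0f; }
static float ffn      (float x) { return x * 2.0f; }

// Illustrative only: contrasts the residual wirings selected by the
// `if (model.layers[il].ffn_norm)` branch above.
static float stablelm_block(float x, bool has_ffn_norm) {
    const float n        = attn_norm(x);   // inpSA in the diff
    const float attn_out = attn(n);        // cur after self-attention
    const float ffn_inp  = x + attn_out;   // ggml_add(ctx0, cur, inpL)

    // sequential residual: the FFN reads a fresh norm of the partial sum;
    // parallel residual (StableLM 2 12B): the FFN reuses attention's normalized input.
    const float ffn_in = has_ffn_norm ? ffn_norm(ffn_inp) : n;

    return ffn_inp + ffn(ffn_in);          // typical final combination (outside the hunk shown)
}

int main() {
    std::printf("sequential: %f\n", stablelm_block(1.0f, true));
    std::printf("parallel:   %f\n", stablelm_block(1.0f, false));
    return 0;
}
```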