
Commit 600896b

llama : move rope factors from KV header to tensors
1 parent d93b5ca commit 600896b

4 files changed: +46 −73 lines

convert-hf-to-gguf.py

Lines changed: 2 additions & 2 deletions
@@ -1834,8 +1834,8 @@ def set_gguf_parameters(self):
         if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
             raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
 
-        self.gguf_writer.add_rope_scaling_freq_long_factors(long_factors)
-        self.gguf_writer.add_rope_scaling_freq_short_factors(short_factors)
+        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG]  + ".weight", np.array(long_factors, dtype=np.float32))
+        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
 
 
 @Model.register("PlamoForCausalLM")
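For orientation, here is a minimal stand-alone sketch of the writer-side convention this hunk adopts, using gguf-py from this repo. The file name, head size, and factor values are placeholders; only the tensor names (rope_factors_long.weight / rope_factors_short.weight) come from the diff above.

    # Minimal sketch (gguf-py from this repo; file name, head size, and factor
    # values are placeholders). The rope scaling factors are written as ordinary
    # tensors instead of KV-header arrays.
    import numpy as np
    import gguf

    writer = gguf.GGUFWriter("phi3-rope-factors-demo.gguf", arch="phi3")

    n_embd_head   = 96                                          # hypothetical head size
    long_factors  = np.ones(n_embd_head // 2, dtype=np.float32)
    short_factors = np.ones(n_embd_head // 2, dtype=np.float32)

    # one factor per rotary dimension pair, stored like any other model tensor
    writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG]  + ".weight", long_factors)
    writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", short_factors)

    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()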

gguf-py/gguf/constants.py

Lines changed: 4 additions & 2 deletions
@@ -61,8 +61,6 @@ class Rope:
     FREQ_BASE             = "{arch}.rope.freq_base"
     SCALING_TYPE          = "{arch}.rope.scaling.type"
     SCALING_FACTOR        = "{arch}.rope.scaling.factor"
-    SCALING_LONG_FACTORS  = "{arch}.rope.scaling.freq_long_factors"
-    SCALING_SHORT_FACTORS = "{arch}.rope.scaling.freq_short_factors"
     SCALING_ATTN_FACTOR   = "{arch}.rope.scaling.attn_factor"
     SCALING_ORIG_CTX_LEN  = "{arch}.rope.scaling.original_context_length"
     SCALING_FINETUNED     = "{arch}.rope.scaling.finetuned"
@@ -151,6 +149,8 @@ class MODEL_TENSOR(IntEnum):
     OUTPUT             = auto()
     OUTPUT_NORM        = auto()
     ROPE_FREQS         = auto()
+    ROPE_FACTORS_LONG  = auto()
+    ROPE_FACTORS_SHORT = auto()
     ATTN_Q             = auto()
     ATTN_K             = auto()
     ATTN_V             = auto()
@@ -228,6 +228,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.OUTPUT_NORM:        "output_norm",
     MODEL_TENSOR.OUTPUT:             "output",
     MODEL_TENSOR.ROPE_FREQS:         "rope_freqs",
+    MODEL_TENSOR.ROPE_FACTORS_LONG:  "rope_factors_long",
+    MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
     MODEL_TENSOR.ATTN_NORM:          "blk.{bid}.attn_norm",
     MODEL_TENSOR.ATTN_NORM_2:        "blk.{bid}.attn_norm_2",
     MODEL_TENSOR.ATTN_QKV:           "blk.{bid}.attn_qkv",

gguf-py/gguf/gguf_writer.py

Lines changed: 0 additions & 6 deletions
@@ -433,12 +433,6 @@ def add_rope_scaling_type(self, value: RopeScalingType) -> None:
     def add_rope_scaling_factor(self, value: float) -> None:
         self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
 
-    def add_rope_scaling_freq_long_factors(self, value: Sequence[float]) -> None:
-        self.add_array(Keys.Rope.SCALING_LONG_FACTORS.format(arch=self.arch), value)
-
-    def add_rope_scaling_freq_short_factors(self, value: Sequence[float]) -> None:
-        self.add_array(Keys.Rope.SCALING_SHORT_FACTORS.format(arch=self.arch), value)
-
     def add_rope_scaling_attn_factors(self, value: Sequence[float]) -> None:
         self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)
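With the two writer methods gone, the factors should appear only as tensors in converted files. A rough check with gguf-py's reader, assuming the GGUFReader API available at the time (the file path is a placeholder):

    # Rough sanity check on a freshly converted Phi-3 GGUF (path is a placeholder;
    # assumes gguf.GGUFReader with .tensors and .fields as in gguf-py).
    from gguf import GGUFReader

    reader = GGUFReader("phi-3-mini-128k-instruct.gguf")

    tensor_names = {t.name for t in reader.tensors}
    print("rope_factors_long.weight"  in tensor_names)  # expected: True
    print("rope_factors_short.weight" in tensor_names)  # expected: True

    # the removed KV keys should no longer be written
    print(any(k.endswith("rope.scaling.freq_long_factors") for k in reader.fields))  # expected: False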

llama.cpp

Lines changed: 40 additions & 63 deletions
@@ -304,8 +304,6 @@ enum llm_kv {
     LLM_KV_ROPE_SCALE_LINEAR,
     LLM_KV_ROPE_SCALING_TYPE,
     LLM_KV_ROPE_SCALING_FACTOR,
-    LLM_KV_ROPE_SCALING_LONG_FACTORS,
-    LLM_KV_ROPE_SCALING_SHORT_FACTORS,
     LLM_KV_ROPE_SCALING_ATTN_FACTOR,
     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
     LLM_KV_ROPE_SCALING_FINETUNED,
@@ -384,8 +382,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_SCALE_LINEAR,          "%s.rope.scale_linear"                    },
     { LLM_KV_ROPE_SCALING_TYPE,          "%s.rope.scaling.type"                    },
     { LLM_KV_ROPE_SCALING_FACTOR,        "%s.rope.scaling.factor"                  },
-    { LLM_KV_ROPE_SCALING_LONG_FACTORS,  "%s.rope.scaling.freq_long_factors"       },
-    { LLM_KV_ROPE_SCALING_SHORT_FACTORS, "%s.rope.scaling.freq_short_factors"      },
     { LLM_KV_ROPE_SCALING_ATTN_FACTOR,   "%s.rope.scaling.attn_factor"             },
     { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,  "%s.rope.scaling.original_context_length" },
     { LLM_KV_ROPE_SCALING_FINETUNED,     "%s.rope.scaling.finetuned"               },
@@ -442,6 +438,8 @@ enum llm_tensor {
     LLM_TENSOR_OUTPUT,
     LLM_TENSOR_OUTPUT_NORM,
     LLM_TENSOR_ROPE_FREQS,
+    LLM_TENSOR_ROPE_FACTORS_LONG,
+    LLM_TENSOR_ROPE_FACTORS_SHORT,
     LLM_TENSOR_ATTN_Q,
     LLM_TENSOR_ATTN_K,
     LLM_TENSOR_ATTN_V,
@@ -809,18 +807,20 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
     {
         LLM_ARCH_PHI3,
         {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ROPE_FACTORS_LONG,  "rope_factors_long" },
+            { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,           "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
         },
     },
     {
@@ -1756,14 +1756,11 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;
 
+    float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
     float rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;
 
-    std::vector<float> rope_long_factors;
-    std::vector<float> rope_short_factors;
-    float rope_attn_factor = 1.0f;
-
     // for State Space Models
     uint32_t ssm_d_conv = 0;
     uint32_t ssm_d_inner = 0;
@@ -1799,10 +1796,6 @@ struct llama_hparams {
         if (this->rope_finetuned  != other.rope_finetuned)  return true;
         if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
 
-        if (this->rope_long_factors  != other.rope_long_factors)  return true;
-        if (this->rope_short_factors != other.rope_short_factors) return true;
-        if (this->rope_attn_factor   != other.rope_attn_factor)   return true;
-
         if (this->ssm_d_conv  != other.ssm_d_conv)  return true;
         if (this->ssm_d_inner != other.ssm_d_inner) return true;
         if (this->ssm_d_state != other.ssm_d_state) return true;
@@ -1812,6 +1805,7 @@ struct llama_hparams {
 
         if (!is_float_close(this->f_norm_eps,     other.f_norm_eps,     EPSILON)) return true;
         if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true;
+        if (!is_float_close(this->rope_attn_factor, other.rope_attn_factor, EPSILON)) return true;
         if (!is_float_close(this->rope_freq_base_train,  other.rope_freq_base_train,  EPSILON)) return true;
         if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
 
@@ -2117,6 +2111,10 @@ struct llama_model {
     struct ggml_tensor * output;
     struct ggml_tensor * output_b;
 
+    // long rope factors
+    struct ggml_tensor * rope_long;
+    struct ggml_tensor * rope_short;
+
     std::vector<llama_layer> layers;
 
     llama_split_mode split_mode;
@@ -2260,8 +2258,6 @@ struct llama_context {
     struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
     struct ggml_tensor * inp_s_seq;  // I32 [n_kv, n_batch]
 
-    struct ggml_tensor * freq_factors = nullptr; // F32 [kv_size / 2]
-
     // control vectors
     struct llama_control_vector cvec;
 };
@@ -3898,12 +3894,6 @@ static void llm_load_hparams(
         }
         hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
 
-        ml.get_arr(LLM_KV_ROPE_SCALING_LONG_FACTORS,  hparams.rope_long_factors,  false);
-        ml.get_arr(LLM_KV_ROPE_SCALING_SHORT_FACTORS, hparams.rope_short_factors, false);
-
-        GGML_ASSERT(hparams.rope_long_factors.size() == 0 || hparams.rope_long_factors.size() == hparams.n_embd / hparams.n_head / 2);
-        GGML_ASSERT(hparams.rope_long_factors.size() == hparams.rope_short_factors.size());
-
         ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
 
         // sanity check for n_rot (optional)
@@ -4937,6 +4927,7 @@ static bool llm_load_tensors(
     // create tensors for the weights
     {
         const int64_t n_embd       = hparams.n_embd;
+        const int64_t n_embd_head  = n_embd / hparams.n_head;
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
         const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
         const int64_t n_embd_gqa   = n_embd_v_gqa;
@@ -5648,6 +5639,9 @@ static bool llm_load_tensors(
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
 
+                model.rope_long  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight"), { n_embd_head/2 }, false);
+                model.rope_short = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, false);
+
                 // output
                 {
                     model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
@@ -6878,7 +6872,7 @@ struct llm_build_context {
         cb(lctx.inp_K_shift, "K_shift", -1);
         ggml_set_input(lctx.inp_K_shift);
 
-        lctx.freq_factors = build_freq_factors();
+        struct ggml_tensor * rope_factors = build_rope_factors();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * tmp =
@@ -6889,7 +6883,7 @@ struct llm_build_context {
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                         0),
-                    lctx.inp_K_shift, lctx.freq_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow);
 
             cb(tmp, "K_shifted", il);
@@ -6994,17 +6988,15 @@ struct llm_build_context {
         return lctx.inp_pos;
     }
 
-    struct ggml_tensor * build_freq_factors() {
-        if (hparams.rope_long_factors.empty() || hparams.rope_short_factors.empty()) {
-            lctx.freq_factors = nullptr;
-            return nullptr;
-        }
+    struct ggml_tensor * build_rope_factors() {
+        // choose long/short freq factors based on the context size
+        const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
 
-        lctx.freq_factors = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_embd_head_k / 2);
-        cb(lctx.freq_factors, "freq_factors", -1);
-        ggml_set_input(lctx.freq_factors);
+        if (n_ctx_pre_seq > hparams.n_yarn_orig_ctx) {
+            return model.rope_long;
+        }
 
-        return lctx.freq_factors;
+        return model.rope_short;
     }
 
     struct ggml_tensor * build_inp_out_ids() {
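The selection rule introduced here is small enough to restate outside the C++: the graph uses the long factors only when the per-sequence context exceeds the model's original training context (n_yarn_orig_ctx). A toy Python mirror of that decision, with made-up numbers in the spirit of Phi-3-mini-128k:

    # Mirrors build_rope_factors() above; illustrative only, not llama.cpp code.
    def pick_rope_factors(n_ctx: int, n_seq_max: int, n_yarn_orig_ctx: int) -> str:
        n_ctx_per_seq = n_ctx // n_seq_max
        return "rope_factors_long" if n_ctx_per_seq > n_yarn_orig_ctx else "rope_factors_short"

    print(pick_rope_factors(n_ctx=32768, n_seq_max=1, n_yarn_orig_ctx=4096))  # rope_factors_long
    print(pick_rope_factors(n_ctx=2048,  n_seq_max=1, n_yarn_orig_ctx=4096))  # rope_factors_short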
@@ -9126,7 +9118,9 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         // rope freq factors for 128k context
-        struct ggml_tensor* freq_factors = build_freq_factors();
+        struct ggml_tensor * rope_factors = build_rope_factors();
+
+        GGML_ASSERT(rope_factors != nullptr && "rope_factors is required for phi3"); // TMP: remove me
 
         for (int il = 0; il < n_layer; ++il) {
             auto residual = inpL;
@@ -9165,7 +9159,7 @@ struct llm_build_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
                 Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, freq_factors, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
@@ -9174,7 +9168,7 @@ struct llm_build_context {
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, freq_factors, n_rot, rope_type, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -10966,23 +10960,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (lctx.freq_factors) {
-        // TODO: this might have to be hparams.n_rot instead of hparams.n_embd_head_k, but maybe it does not matter
-        const auto freq_dim = hparams.n_embd_head_k / 2;
-
-        GGML_ASSERT(lctx.freq_factors->ne[0] == freq_dim);
-        GGML_ASSERT(hparams.rope_long_factors.size()  == freq_dim);
-        GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim);
-
-        // choose long/short freq factors based on the context size
-        const auto n_ctx = llama_n_ctx(&lctx);
-        if (n_ctx > hparams.n_yarn_orig_ctx) {
-            ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(),  0, freq_dim * ggml_element_size(lctx.freq_factors));
-        } else {
-            ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
-        }
-    }
-
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
