@@ -9036,8 +9036,8 @@ static int llama_decode_internal(
90369036 //llama_synchronize(&lctx);
90379037
90389038 // decide if we need to defrag the kv cache
9039- if (cparams.defrag_thold >= 0.0f) {
9040- const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens_all )/float(kv_self.n) : 0.0f;
9039+ if (cparams.causal_attn && cparams. defrag_thold >= 0.0f) {
9040+ const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
90419041
90429042 // queue defragmentation for next llama_kv_cache_update
90439043 if (fragmentation > cparams.defrag_thold) {
@@ -9069,6 +9069,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
90699069 // number of cells moved
90709070 uint32_t n_moves = 0;
90719071
9072+ // each move requires 6*n_layer tensors (see build_defrag)
9073+ // - source view, destination view, copy operation
9074+ // - x2 for keys and values
9075+ const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
9076+
90729077 // determine which KV cells to move where
90739078 //
90749079 // cell i moves to ids[i]
@@ -9095,15 +9100,6 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
90959100 nh++;
90969101 }
90979102
9098- // each move requires 6*n_layer tensors (see build_defrag)
9099- // - source view, destination view, copy operation
9100- // - x2 for keys and values
9101- //
9102- if (6*(n_moves + nh)*n_layer >= LLAMA_MAX_NODES) {
9103- // the graph is too big, we cannot move more cells
9104- break;
9105- }
9106-
91079103 uint32_t nf = 0;
91089104 uint32_t is = n_kv - 1;
91099105
@@ -9133,11 +9129,19 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
91339129 // are we moving a continuous block of memory?
91349130 bool cont = false;
91359131
9132+ // should we stop searching for the next move?
9133+ bool stop = false;
9134+
91369135 // go back and move the nf cells to the hole
91379136 for (; i1 < n_kv; ++i1) {
91389137 auto & cell1 = kv_self.cells[i1];
91399138
91409139 if (cell1.is_empty() || ids[i1] != n_kv) {
9140+ if (n_moves == max_moves) {
9141+ stop = true;
9142+ break;
9143+ }
9144+
91419145 cont = false;
91429146 continue;
91439147 }
@@ -9164,6 +9168,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
91649168 }
91659169 }
91669170
9171+ if (stop || n_moves == max_moves) {
9172+ break;
9173+ }
9174+
91679175 //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
91689176
91699177 i0 += nh - 1;
0 commit comments