@@ -1743,6 +1743,7 @@ struct llama_layer {
 struct llama_kv_cell {
     llama_pos pos   = -1;
     llama_pos delta = 0;
+    int32_t   src   = 0; // used by recurrent state models to copy states
 
     std::set<llama_seq_id> seq_id;
 
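For recurrent state models such as Mamba, a cell holds a whole per-sequence state rather than one token, and the new `src` field records which cell that state should be copied from before the next graph runs; `src == i` (the cell's own index) means the state is already in place. A minimal standalone sketch of that invariant, using a hypothetical `state_cell` stand-in rather than the real `llama_kv_cell`:

    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for llama_kv_cell's src field:
    // src == own index  -> state already in place, nothing to copy
    // src == other cell -> pending "copy from that cell"
    struct state_cell {
        int32_t src = 0;
    };

    static void reset_sources(std::vector<state_cell> & cells) {
        for (size_t i = 0; i < cells.size(); ++i) {
            cells[i].src = (int32_t) i; // identity
        }
    }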
@@ -1763,6 +1764,7 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
+    bool do_copy   = false;
     // with Mamba, a cell can hold the state for more than one past token
     bool unlimited = false;
 
@@ -2001,7 +2003,8 @@ struct llama_context {
     struct ggml_tensor * inp_K_shift;   // I32 [kv_size]
     struct ggml_tensor * inp_mean;      // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;       // I32 [n_batch]
-    struct ggml_tensor * inp_s_mask;    // F32 [kv_size] (only used by constant state models like Mamba)
+    struct ggml_tensor * inp_s_copy;    // I32 [kv_size]
+    struct ggml_tensor * inp_s_mask;    // F32 [kv_size]
     struct ggml_tensor * inp_s_seq;     // I32 [kv_size, n_batch]
 
 #ifdef GGML_USE_MPI
@@ -2043,9 +2046,9 @@ static bool llama_kv_cache_init(
 
     if (cache.unlimited) {
         for (uint32_t i = 0; i < cache.size; ++i) {
-            cache.cells[i].delta = i;
+            cache.cells[i].src = i;
         }
-    } // else, delta is already initialized to zero
+    }
 
 #ifdef GGML_USE_CLBLAST
     offload = false;
@@ -2296,19 +2299,20 @@ static void llama_kv_cache_seq_cp(
 
     if (cache.unlimited) {
         if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
-            seq_id_src = cache.cells[seq_id_src].delta;
+            seq_id_src = cache.cells[seq_id_src].src;
             GGML_ASSERT((uint32_t) seq_id_src < cache.size);
             // intent to "copy from"
             // supports copy chains thanks to taking the source of the source
-            cache.cells[seq_id_dst].delta = seq_id_src;
+            cache.cells[seq_id_dst].src = seq_id_src;
 
-            // prevent the destination from getting cleared if the source is not empty
+            // preserve the "keep or clear" status of the copied sequence
             if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
                 cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+            } else {
+                cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
             }
-            // repurposed as a "need copy" flag
-            // (shifting can't be done anyway for this kind of KV cache)
-            cache.has_shift = true;
+
+            cache.do_copy = true;
 
             cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
         }
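Because `seq_id_src` is first replaced by the source cell's own `src`, chained copies resolve to the original cell instead of stacking indirections. A hedged standalone model of that resolution (illustrative names, not the llama.cpp API):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // src[i] == i means "no pending copy"; otherwise copy from src[i].
    static void seq_cp(std::vector<int32_t> & src, int32_t dst_id, int32_t src_id) {
        src_id      = src[src_id]; // follow the chain: source of the source
        src[dst_id] = src_id;      // dst now copies straight from the origin
    }

    int main() {
        std::vector<int32_t> src = {0, 1, 2};
        seq_cp(src, 1, 0); // seq 0 -> seq 1
        seq_cp(src, 2, 1); // seq 1 -> seq 2, resolves to seq 0
        assert(src[2] == 0);
        return 0;
    }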
@@ -5352,6 +5356,25 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_s_copy() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
+            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
+
+            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
+            ssm_states  = ggml_get_rows(ctx0,  ssm_states, lctx.inp_s_copy);
+
+            // TODO: name the intermediate tensors with cb()
+
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0,  ssm_states, kv_self.v_l[il]));
+        }
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
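Per layer, `ggml_get_rows` gathers whole state rows according to `inp_s_copy`, and `ggml_cpy` writes the result back over the cache tensors, so cells whose index equals their `src` are rewritten unchanged. A plain-C++ model of what this graph computes (shapes and names are illustrative, not the ggml API):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // states holds kv_size rows of row_size floats each;
    // s_copy[i] names the row that row i should take its state from.
    static void copy_states(std::vector<float> & states,
                            const std::vector<int32_t> & s_copy,
                            size_t row_size) {
        std::vector<float> gathered(states.size()); // the ggml_get_rows result
        for (size_t i = 0; i < s_copy.size(); ++i) {
            const float * from = &states[(size_t) s_copy[i] * row_size];
            std::copy(from, from + row_size, &gathered[i * row_size]);
        }
        states = gathered; // the ggml_cpy back into kv_self.k_l / v_l
    }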
@@ -7816,16 +7839,6 @@ struct llm_build_context {
             ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], (d_conv-1)*(d_inner), kv_self.size);
             ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il],  (d_state)*(d_inner), kv_self.size);
 
-            // do copies between states when needed (nothing to do with rope or shifts)
-            // TODO: do this in a another graph, a bit like build_k_shift
-            if (kv_self.has_shift) {
-                conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_K_shift);
-                ssm_states  = ggml_get_rows(ctx0,  ssm_states, lctx.inp_K_shift);
-
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0,  ssm_states, kv_self.v_l[il]));
-            }
-
             // clear states of sequences which are starting at the beginning of this batch
             {
                 ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
@@ -7978,6 +7991,23 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
     return result;
 }
 
+static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
+    llama_batch dummy;
+    dummy.n_tokens = 0;
+
+    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
+
+    struct llm_build_context llm(lctx, dummy, cb, false);
+
+    llm.init();
+
+    struct ggml_cgraph * result = llm.build_s_copy();
+
+    llm.free();
+
+    return result;
+}
+
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
      const llama_batch & batch,
@@ -8113,6 +8143,18 @@ static void llama_set_k_shift(llama_context & lctx) {
     }
 }
 
+static void llama_set_s_copy(llama_context & lctx) {
+    const int64_t kv_size = lctx.kv_self.size;
+
+    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+
+    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+
+    for (int i = 0; i < kv_size; ++i) {
+        data[i] = lctx.kv_self.cells[i].src;
+    }
+}
+
 static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     //
     // set input data
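Note that `llama_set_s_copy` fills all `kv_self.size` entries, not just the `n_kv` cells touched by the current batch, since a pending copy may involve any cell. As a hedged example of the expected buffer contents (helper name and indices are illustrative), a fresh cache followed by one `seq_cp(0 -> 1)` would upload:

    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Illustrative expectation for the inp_s_copy buffer, not llama.cpp API.
    std::vector<int32_t> expected_inp_s_copy(size_t kv_size) {
        std::vector<int32_t> data(kv_size);
        std::iota(data.begin(), data.end(), 0); // identity: nothing to copy
        if (kv_size > 1) {
            data[1] = 0; // after seq_cp(0 -> 1): cell 1 pulls its state from cell 0
        }
        return data;
    }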
@@ -8227,17 +8269,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }
 
     if (kv_self.unlimited) {
-        const int64_t n_kv     = kv_self.n;
+        const int64_t n_kv = kv_self.n;
 
         {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
             float * data = (float *) lctx.inp_s_mask->data;
 
             // states which are not affected by the current batch are left untouched
             for (int i = 0; i < n_kv; ++i) {
-                llama_seq_id    seq_id         = i + lctx.kv_self.head;
-                llama_kv_cell & kv_cell        = lctx.kv_self.cells[seq_id];
-                bool            has_self_seq   = kv_cell.has_seq_id(seq_id);
+                llama_seq_id    seq_id       = i + lctx.kv_self.head;
+                llama_kv_cell & kv_cell      = lctx.kv_self.cells[seq_id];
+                bool            has_self_seq = kv_cell.has_seq_id(seq_id);
 
                 data[i] = (float) has_self_seq;
 
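The mask is 1.0f for cells whose state should survive and 0.0f for cells being (re)started; the Mamba build graph scales each state row by it so restarting sequences begin from zeroed states. A small standalone sketch of that scaling (illustrative, not the actual mask multiplication in the graph):

    #include <vector>

    // Scale each of the n_kv state rows by its mask value:
    // 1.0f keeps the state, 0.0f clears it for a restarting sequence.
    static void apply_state_mask(std::vector<float> & states,
                                 const std::vector<float> & s_mask,
                                 size_t row_size) {
        for (size_t i = 0; i < s_mask.size(); ++i) {
            for (size_t j = 0; j < row_size; ++j) {
                states[i * row_size + j] *= s_mask[i];
            }
        }
    }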
@@ -8739,7 +8781,27 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
             kv_self.has_shift = false;
 
             for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].delta = kv_self.unlimited ? i : 0;
+                kv_self.cells[i].delta = 0;
+            }
+        }
+    }
+
+    if (lctx.kv_self.unlimited && lctx.kv_self.do_copy) {
+        llama_set_s_copy(lctx);
+
+        {
+            ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
+
+            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+        }
+
+        {
+            auto & kv_self = lctx.kv_self;
+
+            kv_self.do_copy = false;
+
+            for (uint32_t i = 0; i < kv_self.size; ++i) {
+                kv_self.cells[i].src = i;
             }
         }
     }
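This mirrors the K-shift path just above: upload the per-cell `src` indices, run the dedicated copy graph once, then reset every cell back to the identity (`src = i`) so nothing is copied again until the next `seq_cp`. A hedged pseudo-trace of one update, assuming a 4-cell cache after a prior `seq_cp(0 -> 1)`:

    // Before: cells[].src == {0, 0, 2, 3}, do_copy == true
    //   llama_set_s_copy(lctx)       uploads {0, 0, 2, 3} into inp_s_copy
    //   llama_graph_compute(...)     row 1 of every state tensor := row 0
    // After:  cells[].src == {0, 1, 2, 3}, do_copy == false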
@@ -12418,7 +12480,7 @@ struct llama_context * llama_new_context_with_model(
         // graph inputs
         {
             ggml_init_params init_params = {
-                /* .mem_size   */ ggml_tensor_overhead()*(8 + 2*(ctx->kv_self.unlimited)),
+                /* .mem_size   */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.unlimited)),
                 /* .mem_buffer */ nullptr,
                 /* .no_alloc   */ true,
             };
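The count goes from `8 + 2*...` to `8 + 3*...` because `inp_s_copy` joins `inp_s_mask` and `inp_s_seq` as a third input tensor that only recurrent (unlimited) caches allocate; with `.no_alloc = true` this context holds tensor metadata only, so its size is just the per-tensor overhead times the number of input tensors:

    // kv_self.unlimited == true:  mem_size = ggml_tensor_overhead() * 11
    // kv_self.unlimited == false: mem_size = ggml_tensor_overhead() * 8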
@@ -12433,6 +12495,7 @@ struct llama_context * llama_new_context_with_model(
             ctx->inp_mean    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
             ctx->inp_cls     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
             if (ctx->kv_self.unlimited) {
+                ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
                 ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
                 ctx->inp_s_seq  = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
             }
@@ -12446,6 +12509,7 @@ struct llama_context * llama_new_context_with_model(
             ggml_set_name(ctx->inp_mean,    "inp_mean");
             ggml_set_name(ctx->inp_cls,     "inp_cls");
             if (ctx->kv_self.unlimited) {
+                ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
                 ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
                 ggml_set_name(ctx->inp_s_seq,  "inp_s_seq");
             }