@@ -3502,11 +3502,24 @@ static bool llama_kv_cache_init(
35023502    return true;
35033503}
35043504
3505+ // a structure holds information about the slot found in llama_kv_cache_find_slot
3506+ struct llama_kv_cache_slot_info {
3507+     std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
3508+     bool found = false;                       // the slot was found
3509+ 
3510+     explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
3511+     llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
3512+ 
3513+     operator bool() const { return found; }
3514+ };
3515+ static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
3516+ 
35053517// find an empty slot of size "n_tokens" in the cache
35063518// updates the cache head
3519+ // returns a structure holding information about the slot found
35073520// Note: On success, it's important that cache.head points
35083521// to the first cell of the slot.
3509- static bool  llama_kv_cache_find_slot(
3522+ static struct llama_kv_cache_slot_info  llama_kv_cache_find_slot(
35103523           struct llama_kv_cache & cache,
35113524       const struct llama_ubatch & batch) {
35123525    const uint32_t n_tokens = batch.n_tokens;
@@ -3534,7 +3547,7 @@ static bool llama_kv_cache_find_slot(
35343547                    // too big seq_id
35353548                    // TODO: would it be possible to resize the cache instead?
35363549                    LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
3537-                     return false ;
3550+                     return llama_kv_cache_slot_info_failed ;
35383551                }
35393552                if (j > 0) {
35403553                    llama_kv_cell & seq = cache.cells[seq_id];
@@ -3669,15 +3682,17 @@ static bool llama_kv_cache_find_slot(
36693682        // allow getting the range of used cells, from head to head + n
36703683        cache.head = min;
36713684        cache.n    = max - min + 1;
3685+         cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
3686+             [](const llama_kv_cell& cell){ return !cell.is_empty(); });
36723687
36733688        // sanity check
3674-         return cache.n >= n_seqs;
3689+         return llama_kv_cache_slot_info( cache.n >= n_seqs) ;
36753690    }
36763691    // otherwise, one cell per token.
36773692
36783693    if (n_tokens > cache.size) {
36793694        LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
3680-         return false ;
3695+         return llama_kv_cache_slot_info_failed ;
36813696    }
36823697
36833698    uint32_t n_tested = 0;
@@ -3705,7 +3720,7 @@ static bool llama_kv_cache_find_slot(
37053720
37063721        if (n_tested >= cache.size) {
37073722            //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
3708-             return false ;
3723+             return llama_kv_cache_slot_info_failed ;
37093724        }
37103725    }
37113726
@@ -3722,7 +3737,7 @@ static bool llama_kv_cache_find_slot(
37223737
37233738    cache.used += n_tokens;
37243739
3725-     return true ;
3740+     return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens) ;
37263741}
37273742
37283743// find how many cells are currently in use
@@ -3998,6 +4013,53 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
39984013    return cparams.flash_attn ? 256u : 32u;
39994014}
40004015
4016+ // saves the kv_cache state for future recovery.
4017+ // used to rollback llama_kv_cache_find_slot changes.
4018+ struct llama_kv_slot_restorer {
4019+     struct llama_kv_cache_state {
4020+         uint32_t head = 0;
4021+         uint32_t n    = 0;
4022+     } old_state;
4023+ 
4024+     // for non-recurrent models only
4025+     // list of slots to restore
4026+     std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
4027+ 
4028+     bool do_restore = false;
4029+ 
4030+     explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
4031+         old_state.head  = cache.head;
4032+         old_state.n     = cache.n;
4033+     }
4034+ 
4035+     // saves a slot information for future restoration
4036+     void save(const struct llama_kv_cache_slot_info & slot) {
4037+         if (slot) {
4038+             do_restore = true;
4039+             if (slot.boundaries.first != slot.boundaries.second) {
4040+                 slot_boundaries.push_back(slot.boundaries);
4041+             }
4042+         }
4043+     }
4044+ 
4045+     // must be explicitly called to restore the kv_cache state
4046+     // and rollback changes from all llama_kv_cache_find_slot calls
4047+     void restore(struct llama_kv_cache & cache) {
4048+         if (do_restore) {
4049+             cache.head  = old_state.head;
4050+             cache.n     = old_state.n;
4051+ 
4052+             if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
4053+                 llama_kv_cache_seq_rm(cache, -1, -1, -1);
4054+             } else {
4055+                 for (auto & slot : slot_boundaries) {
4056+                     llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
4057+                 }
4058+             }
4059+         }
4060+     }
4061+ };
4062+ 
40014063//
40024064// model loading and saving
40034065//
@@ -17181,7 +17243,8 @@ static void llama_output_reorder(struct llama_context * ctx) {
1718117243    }
1718217244}
1718317245
17184- static void llama_graph_compute(
17246+ // returns the result of ggml_backend_sched_graph_compute_async execution
17247+ static enum ggml_status llama_graph_compute(
1718517248          llama_context & lctx,
1718617249            ggml_cgraph * gf,
1718717250                    int   n_threads,
@@ -17196,15 +17259,20 @@ static void llama_graph_compute(
1719617259        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
1719717260    }
1719817261
17199-     auto err  = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
17200-     if (err  != GGML_STATUS_SUCCESS) {
17201-         LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err );
17262+     auto status  = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
17263+     if (status  != GGML_STATUS_SUCCESS) {
17264+         LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status );
1720217265    }
1720317266
1720417267    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
17268+ 
17269+     return status;
1720517270}
1720617271
1720717272// decode a batch of tokens by evaluating the transformer
17273+ // in case of unsuccessful decoding (error or warning),
17274+ // the kv_cache state will be returned to its original state
17275+ // (for non-recurrent models) or cleaned (for recurrent models)
1720817276//
1720917277//   - lctx:      llama context
1721017278//   - batch:     batch to evaluate
@@ -17254,6 +17322,7 @@ static int llama_decode_internal(
1725417322    lctx.n_queued_tokens += n_tokens_all;
1725517323
1725617324    auto & kv_self = lctx.kv_self;
17325+     llama_kv_slot_restorer kv_slot_restorer(kv_self);
1725717326
1725817327    const int64_t n_embd  = hparams.n_embd;
1725917328    const int64_t n_vocab = hparams.n_vocab;
@@ -17338,9 +17407,11 @@ static int llama_decode_internal(
1733817407                kv_self.head = 0;
1733917408            }
1734017409
17341-             if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
17410+             const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
17411+             if (!slot) {
1734217412                return 1;
1734317413            }
17414+             kv_slot_restorer.save(slot);
1734417415
1734517416            if (!kv_self.recurrent) {
1734617417                // a heuristic, to avoid attending the full cache if it is not yet utilized
@@ -17387,7 +17458,19 @@ static int llama_decode_internal(
1738717458
1738817459        llama_set_inputs(lctx, ubatch);
1738917460
17390-         llama_graph_compute(lctx, gf, n_threads, threadpool);
17461+         const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
17462+         if (compute_status != GGML_STATUS_SUCCESS) {
17463+             kv_slot_restorer.restore(kv_self);
17464+             switch (compute_status) {
17465+                 case GGML_STATUS_ABORTED:
17466+                     return 2;
17467+                 case GGML_STATUS_ALLOC_FAILED:
17468+                     return -2;
17469+                 case GGML_STATUS_FAILED:
17470+                 default:
17471+                     return -3;
17472+             }
17473+         }
1739117474
1739217475        // update the kv ring buffer
1739317476        {
@@ -17624,7 +17707,18 @@ static int llama_encode_internal(
1762417707
1762517708    llama_set_inputs(lctx, ubatch);
1762617709
17627-     llama_graph_compute(lctx, gf, n_threads, threadpool);
17710+     const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
17711+     switch (compute_status) {
17712+         case GGML_STATUS_SUCCESS:
17713+             break;
17714+         case GGML_STATUS_ABORTED:
17715+             return 2;
17716+         case GGML_STATUS_ALLOC_FAILED:
17717+             return -2;
17718+         case GGML_STATUS_FAILED:
17719+         default:
17720+             return -3;
17721+     }
1762817722
1762917723    // extract embeddings
1763017724    if (embd) {
0 commit comments