@@ -2815,6 +2815,42 @@ struct llama_kv_cache {
28152815 }
28162816};
28172817
2818+ class llama_kv_cache_state {
2819+ struct llama_kv_cache_state_short {
2820+ uint32_t head = 0;
2821+ uint32_t size = 0;
2822+ uint32_t used = 0;
2823+ uint32_t n = 0;
2824+
2825+ std::vector<llama_kv_cell> cells;
2826+ } old_state;
2827+
2828+ bool saved = false;
2829+
2830+ public:
2831+ void save_state(const llama_kv_cache& cache) {
2832+ old_state.head = cache.head;
2833+ old_state.size = cache.size;
2834+ old_state.used = cache.used;
2835+ old_state.n = cache.n;
2836+ old_state.cells = cache.cells;
2837+
2838+ saved = true;
2839+ }
2840+
2841+ void restore(llama_kv_cache& cache) {
2842+ if (saved) {
2843+ cache.head = old_state.head;
2844+ cache.size = old_state.size;
2845+ cache.used = old_state.used;
2846+ cache.n = old_state.n;
2847+ cache.cells = std::move(old_state.cells);
2848+
2849+ saved = false;
2850+ }
2851+ }
2852+ };
2853+
28182854struct llama_control_vector {
28192855 std::vector<struct ggml_tensor *> tensors; // per layer
28202856 std::vector<struct ggml_context *> ctxs;
@@ -17184,6 +17220,7 @@ static int llama_decode_internal(
1718417220 lctx.n_queued_tokens += n_tokens_all;
1718517221
1718617222 auto & kv_self = lctx.kv_self;
17223+ llama_kv_cache_state kv_cache_state_holder;
1718717224
1718817225 const int64_t n_embd = hparams.n_embd;
1718917226 const int64_t n_vocab = hparams.n_vocab;
@@ -17261,6 +17298,7 @@ static int llama_decode_internal(
1726117298 // non-causal masks do not use the KV cache
1726217299 if (hparams.causal_attn) {
1726317300 llama_kv_cache_update(&lctx);
17301+ kv_cache_state_holder.save_state(kv_self);
1726417302
1726517303 // if we have enough unused cells before the current head ->
1726617304 // better to start searching from the beginning of the cache, hoping to fill it
@@ -17318,16 +17356,17 @@ static int llama_decode_internal(
1731817356 llama_set_inputs(lctx, ubatch);
1731917357
1732017358 const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
17321- switch (compute_status) {
17322- case GGML_STATUS_SUCCESS:
17323- break;
17324- case GGML_STATUS_ABORTED:
17325- return 2;
17326- case GGML_STATUS_ALLOC_FAILED:
17327- return -2;
17328- case GGML_STATUS_FAILED:
17329- default:
17330- return -3;
17359+ if (compute_status != GGML_STATUS_SUCCESS) {
17360+ kv_cache_state_holder.restore(kv_self);
17361+ switch (compute_status) {
17362+ case GGML_STATUS_ABORTED:
17363+ return 2;
17364+ case GGML_STATUS_ALLOC_FAILED:
17365+ return -2;
17366+ case GGML_STATUS_FAILED:
17367+ default:
17368+ return -3;
17369+ }
1733117370 }
1733217371
1733317372 // update the kv ring buffer
0 commit comments