@@ -2806,6 +2806,42 @@ struct llama_kv_cache {
28062806 }
28072807};
28082808
2809+ class llama_kv_cache_state {
2810+ struct llama_kv_cache_state_short {
2811+ uint32_t head = 0;
2812+ uint32_t size = 0;
2813+ uint32_t used = 0;
2814+ uint32_t n = 0;
2815+
2816+ std::vector<llama_kv_cell> cells;
2817+ } old_state;
2818+
2819+ bool saved = false;
2820+
2821+ public:
2822+ void save_state(const llama_kv_cache& cache) {
2823+ old_state.head = cache.head;
2824+ old_state.size = cache.size;
2825+ old_state.used = cache.used;
2826+ old_state.n = cache.n;
2827+ old_state.cells = cache.cells;
2828+
2829+ saved = true;
2830+ }
2831+
2832+ void restore(llama_kv_cache& cache) {
2833+ if (saved) {
2834+ cache.head = old_state.head;
2835+ cache.size = old_state.size;
2836+ cache.used = old_state.used;
2837+ cache.n = old_state.n;
2838+ cache.cells = std::move(old_state.cells);
2839+
2840+ saved = false;
2841+ }
2842+ }
2843+ };
2844+
28092845struct llama_control_vector {
28102846 std::vector<struct ggml_tensor *> tensors; // per layer
28112847 std::vector<struct ggml_context *> ctxs;
@@ -16687,6 +16723,7 @@ static int llama_decode_internal(
1668716723 lctx.n_queued_tokens += n_tokens_all;
1668816724
1668916725 auto & kv_self = lctx.kv_self;
16726+ llama_kv_cache_state kv_cache_state_holder;
1669016727
1669116728 const int64_t n_embd = hparams.n_embd;
1669216729 const int64_t n_vocab = hparams.n_vocab;
@@ -16764,6 +16801,7 @@ static int llama_decode_internal(
1676416801 // non-causal masks do not use the KV cache
1676516802 if (hparams.causal_attn) {
1676616803 llama_kv_cache_update(&lctx);
16804+ kv_cache_state_holder.save_state(kv_self);
1676716805
1676816806 // if we have enough unused cells before the current head ->
1676916807 // better to start searching from the beginning of the cache, hoping to fill it
@@ -16821,16 +16859,17 @@ static int llama_decode_internal(
1682116859 llama_set_inputs(lctx, ubatch);
1682216860
1682316861 const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
16824- switch (compute_status) {
16825- case GGML_STATUS_SUCCESS:
16826- break;
16827- case GGML_STATUS_ABORTED:
16828- return 2;
16829- case GGML_STATUS_ALLOC_FAILED:
16830- return -2;
16831- case GGML_STATUS_FAILED:
16832- default:
16833- return -3;
16862+ if (compute_status != GGML_STATUS_SUCCESS) {
16863+ kv_cache_state_holder.restore(kv_self);
16864+ switch (compute_status) {
16865+ case GGML_STATUS_ABORTED:
16866+ return 2;
16867+ case GGML_STATUS_ALLOC_FAILED:
16868+ return -2;
16869+ case GGML_STATUS_FAILED:
16870+ default:
16871+ return -3;
16872+ }
1683416873 }
1683516874
1683616875 // update the kv ring buffer
0 commit comments