@@ -633,7 +633,15 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_graph_type gtype, ggml_status * ret) {
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status * ret) {
+    if (mstate && !mstate->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+        if (ret) {
+            *ret = GGML_STATUS_FAILED;
+        }
+        return nullptr;
+    }
+
     auto * gf = graph_init();
     if (!gf) {
         LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
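
Read together with the caller changes further down in this diff, the new overload is driven roughly as follows. This is a condensed sketch assembled from the decode hunks below (output bookkeeping is elided and the status handling is collapsed into a single check); the names are the ones used in the diff, not a complete excerpt of `llama_context::decode()`:

```cpp
// Condensed sketch of the decode-side call pattern introduced by this diff.
// Everything here is lifted from the hunks below; error paths are simplified.
auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled,
                                    /* logits_all */ n_outputs_all == n_tokens_all);
if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
    return -2;
}

do {
    const auto & ubatch = kv_state->get_ubatch();

    ggml_status status;
    // passing the state makes process_ubatch() call mstate->apply() before building
    // the graph; the encoder path passes nullptr and skips that step
    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), &status);
    if (!res) {
        break; // status carries the failure reason (e.g. GGML_STATUS_FAILED)
    }
} while (kv_state->next());
```
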
@@ -748,7 +756,7 @@ int llama_context::encode(llama_batch & inp_batch) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    auto res = process(ubatch, LLM_GRAPH_TYPE_ENCODER, &status);
+    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, &status);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -927,12 +935,12 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!decode_state) {
+    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+    if (!kv_state) {
         return -2;
     }
 
-    switch (decode_state->get_status()) {
+    switch (kv_state->get_status()) {
         case LLAMA_MEMORY_STATUS_SUCCESS:
             {
             } break;
@@ -955,8 +963,8 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     int64_t n_outputs_prev = 0;
 
-    while (const auto * ubatch_ptr = decode_state->next()) {
-        const auto & ubatch = *ubatch_ptr;
+    do {
+        const auto & ubatch = kv_state->get_ubatch();
 
         // count the outputs in this u_batch
         {
@@ -979,7 +987,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        auto res = process(ubatch, LLM_GRAPH_TYPE_DECODER, &status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), &status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1092,7 +1100,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         n_outputs_prev += n_outputs;
-    }
+    } while (kv_state->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
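
The loop shape changes because the accessor changes: the old `decode_state->next()` handed back a pointer to the next ubatch, so a plain `while` worked, whereas (judging by the do/while form in these hunks) the new state exposes the current ubatch via `get_ubatch()` and `next()` only advances, returning false once the batch is exhausted, so the body must run before the first `next()`. A self-contained toy illustration of just that control-flow shape (not the llama.cpp API):

```cpp
#include <cstdio>
#include <vector>

// Toy stand-in for the new state interface: get_ubatch() exposes the current
// item, next() advances and reports whether another item is available.
struct toy_state {
    std::vector<int> ubatches{1, 2, 3};
    size_t i = 0;

    const int & get_ubatch() const { return ubatches[i]; }
    bool next() { return ++i < ubatches.size(); }
};

int main() {
    toy_state st;
    do {
        std::printf("processing ubatch %d\n", st.get_ubatch());
    } while (st.next());
}
```
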
@@ -1101,7 +1109,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        auto & out_ids = decode_state->out_ids();
+        auto & out_ids = kv_state->out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
@@ -2020,8 +2028,8 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
-        if (!decode_state || decode_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
         }
@@ -2033,8 +2041,8 @@ void llama_context::opt_epoch_iter(
         };
 
         uint32_t pos_batch = 0;
-        while (const auto * ubatch_ptr = decode_state->next()) {
-            const auto & ubatch = *ubatch_ptr;
+        do {
+            const auto & ubatch = kv_state->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
 
@@ -2073,7 +2081,7 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
-        }
+        } while (kv_state->next());
     }
 }
 