@@ -3676,6 +3676,20 @@ struct server_context {
                         alora_disabled_id = enabled_loras[0];
                     }
 
+                    bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+
+                    // make a checkpoint of the parts of the memory that cannot be rolled back.
+                    // checkpoints are created only if:
+                    //  - the model uses SWA and we are not using `swa_full`
+                    //  - the model architecture is marked as recurrent or hybrid
+                    //
+                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
+                    do_checkpoint = do_checkpoint && (
+                            llama_model_is_recurrent(model) ||
+                            llama_model_is_hybrid(model) ||
+                            (llama_model_n_swa(model) > 0 && !params_base.swa_full)
+                            );
+
                     // add prompt tokens for processing in the current batch
                     while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
                         // get next token to process
@@ -3700,6 +3714,11 @@ struct server_context {
 
                         slot.n_prompt_tokens_processed++;
                         slot.n_past++;
+
+                        // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
+                        if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
+                            break;
+                        }
                     }
 
                     // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
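
The `== 64` break above splits the end of the prompt into its own batch, so the state serialized at prompt-done time does not yet include the final 64 prompt tokens. Below is a minimal, standalone sketch of the resulting batch schedule, not the server code itself; `plan_batches`, `n_prompt`, `n_batch` and `CHECKPOINT_TAIL` are illustrative names introduced here, with the tail length matching the hard-coded 64 in the diff.

#include <algorithm>
#include <cstdio>
#include <vector>

// illustrative only: how a prompt gets split into decode batches when a
// checkpoint will be taken before the last CHECKPOINT_TAIL tokens
static std::vector<int> plan_batches(int n_prompt, int n_batch, bool do_checkpoint) {
    constexpr int CHECKPOINT_TAIL = 64; // mirrors the `== 64` break condition above

    std::vector<int> batch_sizes;
    int n_past = 0;

    while (n_past < n_prompt) {
        int n = std::min(n_batch, n_prompt - n_past);

        // stop the batch early so the checkpoint is captured before the
        // final CHECKPOINT_TAIL prompt tokens are decoded
        if (do_checkpoint && n_prompt - n_past > CHECKPOINT_TAIL) {
            n = std::min(n, n_prompt - n_past - CHECKPOINT_TAIL);
        }

        batch_sizes.push_back(n);
        n_past += n;
    }

    return batch_sizes;
}

int main() {
    // a 1000-token prompt with a 512-token batch: prints "512 424 64"
    for (int n : plan_batches(1000, 512, /*do_checkpoint=*/true)) {
        printf("%d ", n);
    }
    printf("\n");
}
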
@@ -3730,6 +3749,39 @@ struct server_context {
                         slot.i_batch   = batch.n_tokens - 1;
 
                         SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+
+                        const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+                        const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+
+                        // no need for empty or small checkpoints
+                        do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+
+                        // no need to create checkpoints that are too close together
+                        do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
+
+                        if (do_checkpoint) {
+                            while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                                // make room for the new checkpoint, if needed
+                                const auto & cur = slot.ctx_checkpoints.front();
+                                SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                                slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
+                            }
+
+                            const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                            auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
+                                /* .pos_min = */ pos_min,
+                                /* .pos_max = */ pos_max,
+                                /* .data    = */ std::vector<uint8_t>(checkpoint_size),
+                            });
+
+                            llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                            SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                    (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                        }
                     }
                 }
 
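
The hunk above relies on a per-slot `ctx_checkpoint` record that is defined elsewhere in server.cpp. The sketch below shows the shape the diff implies; the field types are an assumption inferred from the calls above (`llama_memory_seq_pos_min`/`_max` return positions, and the data buffer is filled from the serialized sequence state), not copied from the actual definition.

#include <cstdint>
#include <vector>

#include "llama.h"

// assumed shape of the record stored in slot.ctx_checkpoints:
// the position range the checkpoint covers plus the serialized sequence state
struct ctx_checkpoint {
    llama_pos pos_min;          // smallest position present in the sequence when the checkpoint was taken
    llama_pos pos_max;          // largest position covered by the checkpoint
    std::vector<uint8_t> data;  // buffer filled by llama_state_seq_get_data_ext()
};

Restoring a checkpoint would presumably go through the matching llama_state_seq_set_data_ext() call with the same LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY flag, but that path is not part of this diff.
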
@@ -3853,40 +3905,6 @@ struct server_context {
 
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
-
-                    // make a checkpoint of the parts of the memory that cannot be rolled back.
-                    // checkpoints are created only if:
-                    //  - the model uses SWA and we are not using `swa_full`
-                    //  - the model architecture is marked as recurrent or hybrid
-                    //
-                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
-                    const bool do_checkpoint =
-                        (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
-                        (llama_model_n_swa(model) > 0 && !params_base.swa_full);
-
-                    if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
-                        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
-                            // make room for the new checkpoint, if needed
-                            const auto & cur = slot.ctx_checkpoints.front();
-                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-
-                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
-                        }
-
-                        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
-                            /* .pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
-                            /* .pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                            /* .data    = */ std::vector<uint8_t>(checkpoint_size),
-                        });
-
-                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                    }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
                 }