@@ -1,12 +1,9 @@
-#include <vector>
-#include <cstdio>
-#include <chrono>
-
 #include "common.h"
 #include "llama.h"
-#include "llama.cpp"
 
-using namespace std;
+#include <vector>
+#include <cstdio>
+#include <chrono>
 
 int main(int argc, char ** argv) {
     gpt_params params;
@@ -20,21 +17,25 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (params.n_predict < 0) {
+        params.n_predict = 16;
+    }
+
     auto lparams = llama_context_default_params();
 
-    lparams.n_ctx = params.n_ctx;
-    lparams.n_parts = params.n_parts;
-    lparams.seed = params.seed;
-    lparams.f16_kv = params.memory_f16;
-    lparams.use_mmap = params.use_mmap;
-    lparams.use_mlock = params.use_mlock;
+    lparams.n_ctx     = params.n_ctx;
+    lparams.n_parts   = params.n_parts;
+    lparams.seed      = params.seed;
+    lparams.f16_kv    = params.memory_f16;
+    lparams.use_mmap  = params.use_mmap;
+    lparams.use_mlock = params.use_mlock;
 
     auto n_past = 0;
-    auto last_n_tokens_data = vector<llama_token>(params.repeat_last_n, 0);
+    auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
 
     // init
     auto ctx = llama_init_from_file(params.model.c_str(), lparams);
-    auto tokens = vector<llama_token>(params.n_ctx);
+    auto tokens = std::vector<llama_token>(params.n_ctx);
     auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), tokens.size(), true);
 
     if (n_prompt_tokens < 1) {
@@ -43,26 +44,29 @@ int main(int argc, char ** argv) {
     }
 
     // evaluate prompt
-
     llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
 
     last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
     n_past += n_prompt_tokens;
 
+    const size_t state_size = llama_get_state_size(ctx);
+    uint8_t * state_mem = new uint8_t[state_size];
+
     // Save state (rng, logits, embedding and kv_cache) to file
-    FILE *fp_write = fopen("dump_state.bin", "wb");
-    auto state_size = llama_get_state_size(ctx);
-    auto state_mem = new uint8_t[state_size];
-    llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
-    fwrite(state_mem, 1, state_size, fp_write);
-    fclose(fp_write);
+    {
+        FILE *fp_write = fopen("dump_state.bin", "wb");
+        llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
+        fwrite(state_mem, 1, state_size, fp_write);
+        fclose(fp_write);
+    }
 
     // save state (last tokens)
-    auto last_n_tokens_data_saved = vector<llama_token>(last_n_tokens_data);
-    auto n_past_saved = n_past;
+    const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data);
+    const auto n_past_saved = n_past;
 
     // first run
     printf("\n%s", params.prompt.c_str());
+
     for (auto i = 0; i < params.n_predict; i++) {
         auto logits = llama_get_logits(ctx);
         auto n_vocab = llama_n_vocab(ctx);
@@ -75,31 +79,42 @@ int main(int argc, char ** argv) {
         auto next_token = llama_sample_token(ctx, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx, next_token);
         last_n_tokens_data.push_back(next_token);
+
         printf("%s", next_token_str);
         if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             return 1;
         }
         n_past += 1;
     }
+
     printf("\n\n");
 
     // free old model
     llama_free(ctx);
 
     // load new model
-
     auto ctx2 = llama_init_from_file(params.model.c_str(), lparams);
 
     // Load state (rng, logits, embedding and kv_cache) from file
-    FILE *fp_read = fopen("dump_state.bin", "rb");
-    auto state_size2 = llama_get_state_size(ctx2);
-    if (state_size != state_size2) {
-        fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
+    {
+        FILE *fp_read = fopen("dump_state.bin", "rb");
+        if (state_size != llama_get_state_size(ctx2)) {
+            fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
+            return 1;
+        }
+
+        const size_t ret = fread(state_mem, 1, state_size, fp_read);
+        if (ret != state_size) {
+            fprintf(stderr, "\n%s : failed to read state\n", __func__);
+            return 1;
+        }
+
+        llama_set_state_data(ctx2, state_mem); // could also read directly from memory mapped file
+        fclose(fp_read);
     }
-    fread(state_mem, 1, state_size, fp_read);
-    llama_set_state_data(ctx2, state_mem); // could also read directly from memory mapped file
-    fclose(fp_read);
+
+    delete[] state_mem;
 
     // restore state (last tokens)
     last_n_tokens_data = last_n_tokens_data_saved;
@@ -118,13 +133,16 @@ int main(int argc, char ** argv) {
         auto next_token = llama_sample_token(ctx2, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx2, next_token);
         last_n_tokens_data.push_back(next_token);
+
         printf("%s", next_token_str);
         if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             return 1;
         }
         n_past += 1;
     }
+
     printf("\n\n");
+
     return 0;
 }
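
For reference, the pattern the new code settles on (size the buffer once with llama_get_state_size, round-trip it through llama_copy_state_data / llama_set_state_data, and validate the size before restoring) can be wrapped in a small helper. Below is a minimal sketch using only the API calls already shown in the diff; the copy_state helper name and the file-less in-memory transfer are illustrative, and std::vector stands in for the example's manual new[]/delete[] buffer.

// Sketch: copy the full state (rng, logits, embedding, kv_cache) from one
// context into another, as the example does via dump_state.bin.
#include <cstdint>
#include <cstdio>
#include <vector>

#include "llama.h"

static bool copy_state(llama_context * src, llama_context * dst) {
    // Both contexts must report the same state size for the restore to be valid.
    const size_t state_size = llama_get_state_size(src);
    if (state_size != llama_get_state_size(dst)) {
        fprintf(stderr, "%s : state size mismatch\n", __func__);
        return false;
    }

    // RAII buffer instead of new[]/delete[]; freed automatically on return.
    std::vector<uint8_t> state_mem(state_size);

    llama_copy_state_data(src, state_mem.data()); // serialize src's state
    llama_set_state_data(dst, state_mem.data());  // restore it into dst

    return true;
}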