@@ -1,6 +1,7 @@
 #include "common.h"
 #include "llama.h"
 #include "build-info.h"
+#include "grammar-parser.h"
 
 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
@@ -195,6 +196,8 @@ struct llama_server_context
     llama_context *ctx = nullptr;
     gpt_params params;
 
+    llama_grammar *grammar = nullptr;
+
     bool truncated = false;
     bool stopped_eos = false;
     bool stopped_word = false;
@@ -226,6 +229,7 @@ struct llama_server_context
     void rewind()
     {
         params.antiprompt.clear();
+        params.grammar.clear();
         num_prompt_tokens = 0;
         num_tokens_predicted = 0;
         generated_text = "";
@@ -237,6 +241,7 @@ struct llama_server_context
         stopped_limit = false;
         stopping_word = "";
         multibyte_pending = 0;
+        grammar = nullptr;
 
         n_remain = 0;
         n_past = 0;
@@ -257,6 +262,33 @@ struct llama_server_context
         return true;
     }
 
+    bool loadGrammar()
+    {
+        if (!params.grammar.empty()) {
+            grammar_parser::parse_state parsed_grammar;
+
+            parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+            // will be empty (default) if there are parse errors
+            if (parsed_grammar.rules.empty()) {
+                LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
+                return false;
+            }
+            grammar_parser::print_grammar(stderr, parsed_grammar);
+
+            {
+                auto it = params.logit_bias.find(llama_token_eos());
+                if (it != params.logit_bias.end() && it->second == -INFINITY) {
+                    LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
+                }
+            }
+
+            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+            grammar = llama_grammar_init(
+                grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+        }
+        return true;
+    }
+
     void loadPrompt()
     {
         params.prompt.insert(0, 1, ' '); // always add a first space
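
For reference, loadGrammar follows the usual llama.cpp pattern for turning GBNF text into a llama_grammar: parse, check for parse errors, then initialize at the "root" rule. A minimal sketch of that pattern, with an illustrative grammar string and only calls that appear in this patch:

    // Sketch, not part of the patch: build a grammar from GBNF text.
    // The grammar string is illustrative; it restricts output to yes/no.
    llama_grammar *grammar = nullptr;
    grammar_parser::parse_state parsed =
        grammar_parser::parse("root ::= \"yes\" | \"no\"");
    if (!parsed.rules.empty()) { // empty rules signal a parse error
        std::vector<const llama_grammar_element *> rules(parsed.c_rules());
        grammar = llama_grammar_init(rules.data(), rules.size(),
                                     parsed.symbol_ids.at("root"));
    }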
@@ -420,6 +452,10 @@ struct llama_server_context
                 logits[llama_token_nl()] = nl_logit;
             }
 
+            if (grammar != nullptr) {
+                llama_sample_grammar(ctx, &candidates_p, grammar);
+            }
+
             if (temp <= 0)
             {
                 // Greedy sampling
@@ -457,10 +493,15 @@ struct llama_server_context
                 }
             }
 
+            if (grammar != nullptr) {
+                llama_grammar_accept_token(ctx, grammar, result.tok);
+            }
+
             for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
             {
                 result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
             }
+
             last_n_tokens.erase(last_n_tokens.begin());
             last_n_tokens.push_back(result.tok);
             num_tokens_predicted++;
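
The two sampling hooks are order-sensitive: llama_sample_grammar must prune the candidate list before a token is drawn, and llama_grammar_accept_token must then advance the grammar state with the token that was actually chosen. A condensed sketch of one decoding step (greedy case for brevity), assuming ctx, candidates_p, and grammar are set up as in the surrounding code:

    // Sketch, not part of the patch: one grammar-constrained decode step.
    if (grammar != nullptr) {
        llama_sample_grammar(ctx, &candidates_p, grammar); // mask invalid tokens
    }
    const llama_token tok = llama_sample_token_greedy(ctx, &candidates_p);
    if (grammar != nullptr) {
        llama_grammar_accept_token(ctx, grammar, tok);     // advance grammar state
    }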
@@ -947,6 +988,7 @@ static json format_generation_settings(llama_server_context &llama)
         {"stream", llama.stream},
         {"logit_bias", llama.params.logit_bias},
         {"n_probs", llama.params.n_probs},
+        {"grammar", llama.params.grammar},
     };
 }
 
@@ -1048,6 +1090,7 @@ static void parse_options_completion(const json &body, llama_server_context &llama)
     llama.params.n_keep = body.value("n_keep", default_params.n_keep);
     llama.params.seed = body.value("seed", default_params.seed);
     llama.params.prompt = body.value("prompt", default_params.prompt);
+    llama.params.grammar = body.value("grammar", default_params.grammar);
     llama.params.n_probs = body.value("n_probs", default_params.n_probs);
 
     llama.params.logit_bias.clear();
@@ -1179,6 +1222,12 @@ int main(int argc, char **argv)
 
         parse_options_completion(json::parse(req.body), llama);
 
+        if (!llama.loadGrammar())
+        {
+            res.status = 400;
+            return;
+        }
+
         llama.loadPrompt();
         llama.beginCompletion();
 
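Client-side, the completion endpoint now accepts an optional "grammar" field holding GBNF text; if it fails to parse, loadGrammar returns false and the handler responds with HTTP 400. A hypothetical request body, written with the nlohmann json alias the server already uses (prompt and grammar values are illustrative):

    // Sketch, not part of the patch: a completion request using "grammar".
    json body = {
        {"prompt", "Is the sky blue? Answer: "},  // illustrative prompt
        {"n_predict", 4},
        {"grammar", "root ::= \"yes\" | \"no\""}  // GBNF; parse errors -> 400
    };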
@@ -1334,8 +1383,12 @@ int main(int argc, char **argv)
 
     svr.set_error_handler([](const Request &, Response &res)
                           {
-        res.set_content("File Not Found", "text/plain");
-        res.status = 404; });
+        if (res.status == 400) {
+            res.set_content("Invalid request", "text/plain");
+        } else {
+            res.set_content("File Not Found", "text/plain");
+            res.status = 404;
+        } });
 
     // set timeouts and change hostname and port
     svr.set_read_timeout(sparams.read_timeout);
@@ -1363,6 +1416,9 @@ int main(int argc, char **argv)
         return 1;
     }
 
+    if (llama.grammar != nullptr) {
+        llama_grammar_free(llama.grammar);
+    }
     llama_backend_free();
 
     return 0;