@@ -168,76 +168,19 @@ static llama_token llama_sampling_sample_impl(
                   bool is_resampling) {  // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
-
     const float   temp            = params.temp;
-    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
-    const float   penalty_repeat  = params.penalty_repeat;
-    const float   penalty_freq    = params.penalty_freq;
-    const float   penalty_present = params.penalty_present;
     const int     mirostat        = params.mirostat;
     const float   mirostat_tau    = params.mirostat_tau;
     const float   mirostat_eta    = params.mirostat_eta;
-    const bool    penalize_nl     = params.penalize_nl;
 
-    auto & prev = ctx_sampling->prev;
-    auto & cur  = ctx_sampling->cur;
-
-    llama_token id = 0;
-
-    // Get a pointer to the logits
-    float * logits = llama_get_logits_ith(ctx_main, idx);
-
-    // Declare original_logits at the beginning of the function scope
     std::vector<float> original_logits;
-
+    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
     if (!is_resampling) {
-        // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
-        original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
-    }
-
-    // apply params.logit_bias map
-    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-        logits[it->first] += it->second;
-    }
-
-    if (ctx_cfg) {
-        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
-    }
-
-    cur.clear();
-
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-
-    // apply penalties
-    const auto & penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-    if (penalty_tokens_used_size) {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
-
-        llama_sample_repetition_penalties(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
-
-        if (!penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
-                }
-            }
-        }
-    }
-
-    // If we are in the resampling phase, apply grammar checks before sampling logic
-    if (is_resampling && ctx_sampling->grammar != NULL) {
-        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+        GGML_ASSERT(!original_logits.empty());
     }
+    llama_token id = 0;
+    // Get a pointer to the logits
+    float * logits = llama_get_logits_ith(ctx_main, idx);
 
     if (temp < 0.0) {
         // greedy sampling, with probs
@@ -302,11 +245,13 @@ static llama_token llama_sampling_sample_impl(
     return id;
 }
 
-static llama_token_data_array llama_sample_probability_distribution_impl(
+static llama_token_data_array llama_sampling_prepare_impl(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
+                  const int idx,
+                  bool apply_grammar,
+                  std::vector<float> * original_logits) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -315,6 +260,7 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     const float   penalty_repeat  = params.penalty_repeat;
     const float   penalty_freq    = params.penalty_freq;
     const float   penalty_present = params.penalty_present;
+
     const bool    penalize_nl     = params.penalize_nl;
 
     auto & prev = ctx_sampling->prev;
@@ -323,8 +269,10 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    // Declare original_logits at the beginning of the function scope
-    std::vector<float> original_logits;
+    if (apply_grammar && original_logits != NULL) {
+        // Only make a copy of the original logits if we are applying grammar checks, not sure if I actually have to do this.
+        *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
+    }
 
     // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -364,12 +312,11 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
         }
     }
 
-    // apply grammar checks
-    if (ctx_sampling->grammar != NULL) {
+    // apply grammar checks before sampling logic
+    if (apply_grammar && ctx_sampling->grammar != NULL) {
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
-    llama_sample_softmax(ctx_main, &cur_p);
     return cur_p;
 }
 
@@ -382,12 +329,14 @@ llama_token llama_sampling_sample(
     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
 }
 
-llama_token_data_array llama_sampling_probability_distribution(
+llama_token_data_array llama_sampling_prepare(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
-    return llama_sample_probability_distribution_impl(ctx_sampling, ctx_main, ctx_cfg, idx);
+                  const int idx,
+                  bool apply_grammar,
+                  std::vector<float> * original_logits) {
+    return llama_sampling_prepare_impl(ctx_sampling, ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
 }
 
 void llama_sampling_accept(
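For context, here is a minimal sketch of how a caller might use the refactored entry point. It assumes `ctx_sampling` and `ctx_main` were initialized elsewhere (e.g. via the usual `common`/`llama_decode` setup, not shown), and the wrapper function `sample_with_prepare` is hypothetical; only the `llama_sampling_prepare` and `llama_sample_softmax` signatures come from this diff and the existing API.

```cpp
#include "sampling.h"

// Hypothetical caller. Assumes ctx_sampling and ctx_main were set up
// elsewhere; idx selects which batch position's logits to work on.
void sample_with_prepare(llama_sampling_context * ctx_sampling,
                         llama_context * ctx_main, int idx) {
    std::vector<float> original_logits;

    // First pass: build the candidate array with logit biases, guidance,
    // and penalties applied, enforce grammar constraints, and snapshot the
    // untouched logits into original_logits for a potential resample.
    llama_token_data_array cur_p = llama_sampling_prepare(
            ctx_sampling, ctx_main, /*ctx_cfg=*/nullptr, idx,
            /*apply_grammar=*/true, &original_logits);
    (void) cur_p; // sampling logic would consume cur_p here

    // A caller that only wants the prepared distribution (the old
    // llama_sampling_probability_distribution use case) skips the grammar
    // pass and the logit copy. Since the prepare path no longer calls
    // llama_sample_softmax, normalization is now up to the caller.
    llama_token_data_array dist = llama_sampling_prepare(
            ctx_sampling, ctx_main, /*ctx_cfg=*/nullptr, idx,
            /*apply_grammar=*/false, /*original_logits=*/nullptr);
    llama_sample_softmax(ctx_main, &dist);
}
```

The `apply_grammar` flag takes over the old `is_resampling` special-casing: the first pass constrains candidates and snapshots the logits, and a resample can restore that snapshot instead of re-reading the model output.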