 #include "llama-batch.h"
 
+#include "llama-impl.h"
+#include "llama-cparams.h"
+#include "llama-vocab.h"
+
 #include <cassert>
 #include <cstring>
 #include <algorithm>
@@ -279,9 +283,42 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
             );
 }
 
-llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
-    batch = in_batch;
+llama_batch_allocr::llama_batch_allocr() = default;
+
+bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
+    clear();
+
+    batch = batch_inp;
+
     GGML_ASSERT(batch.n_tokens > 0);
+
+    if (!batch.pos) {
+        if (batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return false;
+        }
+    }
+
+    if (batch.token) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return false;
+            }
+        }
+    }
+
+    if (batch.seq_id) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                if (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
+                    return false;
+                }
+            }
+        }
+    }
+
     if (!batch.pos) {
         assert(p0 >= 0);
         pos.resize(batch.n_tokens);
@@ -290,13 +327,15 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         }
         batch.pos = pos.data();
     }
+
     if (!batch.n_seq_id) {
         n_seq_id.resize(batch.n_tokens);
         for (int32_t i = 0; i < batch.n_tokens; i++) {
             n_seq_id[i] = seq_id_0.size();
         }
         batch.n_seq_id = n_seq_id.data();
     }
+
     if (!batch.seq_id) {
         seq_id.resize(batch.n_tokens + 1);
         seq_id[batch.n_tokens] = NULL;
@@ -305,12 +344,37 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         }
         batch.seq_id = seq_id.data();
     }
+
     if (!batch.logits) {
         // by default return the output only for the last token
         output.resize(batch.n_tokens);
         output[output.size() - 1] = true;
         batch.logits = output.data();
     }
+
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        n_outputs += batch.logits[i] != 0;
+    }
+
+    return true;
+}
+
+const llama_batch & llama_batch_allocr::get_batch() const {
+    return batch;
+}
+
+uint32_t llama_batch_allocr::get_n_outputs() const {
+    return n_outputs;
+}
+
+void llama_batch_allocr::clear() {
+    n_outputs = 0;
+
+    batch = {};
+    pos.clear();
+    n_seq_id.clear();
+    seq_id.clear();
+    output.clear();
 }
 
 //
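
For illustration only, not part of the commit: a minimal sketch of how a caller might drive the new two-step API, where validation moves from the constructor into init() and failure is reported via the bool return instead of aborting. The wrapper function and its -1 error convention are hypothetical; only the llama_batch_allocr members shown in the diff are real.

// Hypothetical caller sketch -- prepare_batch() and its error convention
// are illustrative, not actual llama.cpp code.
static int32_t prepare_batch(llama_batch_allocr & balloc,
                             const llama_batch  & inp,
                             const llama_vocab  & vocab,
                             llama_pos p0) {
    // init() returns false on malformed input (out-of-range token ids
    // or seq_ids), so the caller can reject the batch gracefully
    if (!balloc.init(inp, vocab, p0)) {
        return -1; // invalid input batch
    }

    const llama_batch & batch     = balloc.get_batch();     // batch with defaults filled in
    const uint32_t      n_outputs = balloc.get_n_outputs(); // tokens with logits requested

    // ... decode `batch` here, reserving `n_outputs` output rows ...
    (void) batch;
    (void) n_outputs;

    return 0;
}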
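On the user side, a sketch of a batch that passes the new checks, built with the public llama_batch_init()/llama_batch_free() API from llama.h; the token ids are placeholders and are assumed to be valid for the loaded vocabulary.

// Placeholder token ids: they must be smaller than the model's vocab size
static llama_batch make_valid_batch(void) {
    llama_batch batch = llama_batch_init(/*n_tokens*/ 3, /*embd*/ 0, /*n_seq_max*/ 1);

    for (int32_t i = 0; i < 3; ++i) {
        batch.token   [i]    = 1 + i;  // checked against vocab.n_tokens()
        batch.pos     [i]    = i;      // required: when seq_id is set, pos must be set too
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;      // checked: 0 <= seq_id < LLAMA_MAX_PARALLEL_SEQUENCES
        batch.logits  [i]    = i == 2; // request output only for the last token
    }
    batch.n_tokens = 3;

    return batch; // release later with llama_batch_free(batch)
}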