11#include " llama-context.h"
22
33#include " llama-impl.h"
4+ #include " llama-batch.h"
45#include " llama-io.h"
56#include " llama-memory.h"
67#include " llama-mmap.h"
1819llama_context::llama_context (
1920 const llama_model & model,
2021 llama_context_params params) :
21- model(model) {
22+ model(model),
23+ batch_allocr(std::make_unique<llama_batch_allocr>()) {
2224 LLAMA_LOG_INFO (" %s: constructing llama_context\n " , __func__);
2325
2426 t_start_us = model.t_start_us ;
@@ -494,7 +496,7 @@ float * llama_context::get_logits() {
494496}
495497
496498float * llama_context::get_logits_ith (int32_t i) {
497- int32_t j = -1 ;
499+ int64_t j = -1 ;
498500
499501 try {
500502 if (logits == nullptr ) {
@@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
517519 }
518520 if (j >= n_outputs) {
519521 // This should not happen
520- throw std::runtime_error (format (" corrupt output buffer (j=%d , n_outputs=%d)" , j, n_outputs));
522+ throw std::runtime_error (format (" corrupt output buffer (j=%" PRId64 " , n_outputs=%d)" , j, n_outputs));
521523 }
522524
523525 return logits + j*model.vocab .n_tokens ();
@@ -536,7 +538,7 @@ float * llama_context::get_embeddings() {
536538}
537539
538540float * llama_context::get_embeddings_ith (int32_t i) {
539- int32_t j = -1 ;
541+ int64_t j = -1 ;
540542
541543 try {
542544 if (embd == nullptr ) {
@@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
559561 }
560562 if (j >= n_outputs) {
561563 // This should not happen
562- throw std::runtime_error (format (" corrupt output buffer (j=%d , n_outputs=%d)" , j, n_outputs));
564+ throw std::runtime_error (format (" corrupt output buffer (j=%" PRId64 " , n_outputs=%d)" , j, n_outputs));
563565 }
564566
565567 return embd + j*model.hparams .n_embd ;
@@ -727,18 +729,19 @@ int llama_context::encode(llama_batch & inp_batch) {
727729
728730 // temporary allocate memory for the input batch if needed
729731 // note: during encode, we always pass the full sequence starting from pos = 0
730- llama_batch_allocr batch_allocr (inp_batch, inp_batch.pos ? -1 : 0 );
732+ batch_allocr-> init (inp_batch, inp_batch.pos ? -1 : 0 );
731733
732- const llama_batch & batch = batch_allocr.batch ;
733- const int32_t n_tokens = batch.n_tokens ;
734+ const llama_batch & batch = batch_allocr->get_batch ();
735+
736+ const uint32_t n_tokens = batch.n_tokens ;
734737
735738 const auto & hparams = model.hparams ;
736739
737740 GGML_ASSERT ((!batch.token && batch.embd ) || (batch.token && !batch.embd )); // NOLINT
738741
739742 // TODO: move the validation to the llama_batch_allocr
740743 if (batch.token ) {
741- for (int32_t i = 0 ; i < n_tokens; ++i) {
744+ for (uint32_t i = 0 ; i < n_tokens; ++i) {
742745 if (batch.token [i] < 0 || (uint32_t ) batch.token [i] >= model.vocab .n_tokens ()) {
743746 LLAMA_LOG_ERROR (" %s: invalid token[%d] = %d\n " , __func__, i, batch.token [i]);
744747 return -1 ;
@@ -775,7 +778,7 @@ int llama_context::encode(llama_batch & inp_batch) {
775778 return -2 ;
776779 };
777780
778- for (int32_t i = 0 ; i < n_tokens; ++i) {
781+ for (uint32_t i = 0 ; i < n_tokens; ++i) {
779782 output_ids[i] = i;
780783 }
781784
@@ -831,7 +834,8 @@ int llama_context::encode(llama_batch & inp_batch) {
831834
832835 GGML_ASSERT (!ubatch.equal_seqs ); // TODO: handle equal splits
833836
834- for (int32_t i = 0 ; i < n_tokens; i++) {
837+ // TODO: fix sequence indexing
838+ for (uint32_t i = 0 ; i < n_tokens; i++) {
835839 const llama_seq_id seq_id = ubatch.seq_id [i][0 ];
836840 if (embd_seq_out.find (seq_id) != embd_seq_out.end ()) {
837841 continue ;
@@ -881,7 +885,7 @@ int llama_context::encode(llama_batch & inp_batch) {
881885 // TODO: the sequence indexing here is likely not correct in the general case
882886 // probably works only for split_simple
883887 cross.seq_ids_enc .resize (n_tokens);
884- for (int32_t i = 0 ; i < n_tokens; i++) {
888+ for (uint32_t i = 0 ; i < n_tokens; i++) {
885889 cross.seq_ids_enc [i].clear ();
886890 for (int s = 0 ; s < ubatch.n_seq_id [i]; s++) {
887891 llama_seq_id seq_id = ubatch.seq_id [i][s];
@@ -912,30 +916,30 @@ int llama_context::decode(llama_batch & inp_batch) {
912916 }
913917
914918 // temporary allocate memory for the input batch if needed
915- llama_batch_allocr batch_allocr (inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max (0 ) + 1 );
919+ batch_allocr-> init (inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max (0 ) + 1 );
916920
917- const llama_batch & batch = batch_allocr. batch ;
921+ const llama_batch & batch = batch_allocr-> get_batch () ;
918922
919923 const auto & vocab = model.vocab ;
920924 const auto & hparams = model.hparams ;
921925
922926 const int32_t n_vocab = vocab.n_tokens ();
927+ const int64_t n_embd = hparams.n_embd ;
923928
924- const int64_t n_tokens_all = batch.n_tokens ;
925- const int64_t n_embd = hparams.n_embd ;
929+ const uint32_t n_tokens_all = batch.n_tokens ;
926930
927931 GGML_ASSERT ((!batch.token && batch.embd ) || (batch.token && !batch.embd )); // NOLINT
928932
929933 // TODO: move the validation to the llama_batch_allocr
930934 if (batch.token ) {
931- for (int64_t i = 0 ; i < n_tokens_all; ++i) {
935+ for (uint32_t i = 0 ; i < n_tokens_all; ++i) {
932936 if (batch.token [i] < 0 || (uint32_t ) batch.token [i] >= model.vocab .n_tokens ()) {
933- LLAMA_LOG_ERROR (" %s: invalid token[%" PRId64 " ] = %d\n " , __func__, i, batch.token [i]);
937+ LLAMA_LOG_ERROR (" %s: invalid token[%d ] = %d\n " , __func__, i, batch.token [i]);
934938 return -1 ;
935939 }
936940
937941 if (batch.seq_id && (batch.seq_id [i][0 ] < 0 || batch.seq_id [i][0 ] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
938- LLAMA_LOG_ERROR (" %s: invalid seq_id[%" PRId64 " ] = %d >= %d\n " , __func__, i, batch.seq_id [i][0 ], LLAMA_MAX_PARALLEL_SEQUENCES);
942+ LLAMA_LOG_ERROR (" %s: invalid seq_id[%d ] = %d >= %d\n " , __func__, i, batch.seq_id [i][0 ], LLAMA_MAX_PARALLEL_SEQUENCES);
939943 return -1 ;
940944 }
941945 }
@@ -944,7 +948,7 @@ int llama_context::decode(llama_batch & inp_batch) {
944948 // this indicates we are doing pooled embedding
945949 const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
946950
947- int64_t n_outputs_all = 0 ;
951+ uint32_t n_outputs_all = 0 ;
948952
949953 // count outputs
950954 for (uint32_t i = 0 ; i < n_tokens_all; ++i) {
@@ -954,7 +958,7 @@ int llama_context::decode(llama_batch & inp_batch) {
954958 if (embd_pooled) {
955959 // require that all tokens are output
956960 if (n_outputs_all != n_tokens_all) {
957- LLAMA_LOG_ERROR (" %s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 " , n_tokens_all = %" PRId64 " )\n " ,
961+ LLAMA_LOG_ERROR (" %s: pooled embedding requires that all tokens are output (n_outputs_all = %d , n_tokens_all = %d )\n " ,
958962 __func__, n_outputs_all, n_tokens_all);
959963 return -1 ;
960964 }
@@ -1024,7 +1028,7 @@ int llama_context::decode(llama_batch & inp_batch) {
10241028
10251029 // reserve output buffer
10261030 if (output_reserve (n_outputs_all) < n_outputs_all) {
1027- LLAMA_LOG_ERROR (" %s: could not reserve space for batch with %" PRId64 " outputs\n " , __func__, n_outputs_all);
1031+ LLAMA_LOG_ERROR (" %s: could not reserve space for batch with %d outputs\n " , __func__, n_outputs_all);
10281032 return -2 ;
10291033 };
10301034
@@ -1063,6 +1067,7 @@ int llama_context::decode(llama_batch & inp_batch) {
10631067 pos_min[s] = std::numeric_limits<llama_pos>::max ();
10641068 }
10651069
1070+ // TODO: fix sequence indexing
10661071 for (uint32_t i = 0 ; i < ubatch.n_tokens ; ++i) {
10671072 const auto & seq_id = ubatch.seq_id [i][0 ];
10681073
@@ -1176,14 +1181,14 @@ int llama_context::decode(llama_batch & inp_batch) {
11761181 n_outputs = n_outputs_all;
11771182
11781183 // set output mappings
1179- {
1184+ if (n_outputs > 0 ) {
11801185 bool sorted_output = true ;
11811186
11821187 auto & out_ids = mstate->out_ids ();
11831188
1184- GGML_ASSERT (out_ids.size () == (size_t ) n_outputs_all );
1189+ GGML_ASSERT (out_ids.size () == (size_t ) n_outputs );
11851190
1186- for (int64_t i = 0 ; i < n_outputs_all ; ++i) {
1191+ for (int64_t i = 0 ; i < n_outputs ; ++i) {
11871192 int64_t out_id = out_ids[i];
11881193 output_ids[out_id] = i;
11891194 if (out_id != i) {
@@ -1195,20 +1200,22 @@ int llama_context::decode(llama_batch & inp_batch) {
11951200 // note: this is mostly relevant for recurrent models atm
11961201 if (!sorted_output) {
11971202 const uint32_t n_vocab = model.vocab .n_tokens ();
1198- const uint32_t n_embd = model.hparams .n_embd ;
1203+ const uint64_t n_embd = model.hparams .n_embd ;
11991204
12001205 GGML_ASSERT ((size_t ) n_outputs == out_ids.size ());
12011206
12021207 // TODO: is there something more efficient which also minimizes swaps?
12031208 // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1204- for (int32_t i = 0 ; i < n_outputs - 1 ; ++i) {
1205- int32_t j_min = i;
1206- for (int32_t j = i + 1 ; j < n_outputs; ++j) {
1209+ for (uint32_t i = 0 ; i < n_outputs - 1 ; ++i) {
1210+ uint32_t j_min = i;
1211+ for (uint32_t j = i + 1 ; j < n_outputs; ++j) {
12071212 if (out_ids[j] < out_ids[j_min]) {
12081213 j_min = j;
12091214 }
12101215 }
1211- if (j_min == i) { continue ; }
1216+ if (j_min == i) {
1217+ continue ;
1218+ }
12121219 std::swap (out_ids[i], out_ids[j_min]);
12131220 if (logits_size > 0 ) {
12141221 for (uint32_t k = 0 ; k < n_vocab; k++) {
@@ -1221,8 +1228,10 @@ int llama_context::decode(llama_batch & inp_batch) {
12211228 }
12221229 }
12231230 }
1231+
12241232 std::fill (output_ids.begin (), output_ids.end (), -1 );
1225- for (int32_t i = 0 ; i < n_outputs; ++i) {
1233+
1234+ for (uint32_t i = 0 ; i < n_outputs; ++i) {
12261235 output_ids[out_ids[i]] = i;
12271236 }
12281237 }
@@ -1242,7 +1251,7 @@ int llama_context::decode(llama_batch & inp_batch) {
12421251// output
12431252//
12441253
1245- int32_t llama_context::output_reserve (int32_t n_outputs) {
1254+ uint32_t llama_context::output_reserve (int32_t n_outputs) {
12461255 const auto & hparams = model.hparams ;
12471256 const auto & vocab = model.vocab ;
12481257
@@ -1308,8 +1317,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
13081317 // set all ids as invalid (negative)
13091318 std::fill (output_ids.begin (), output_ids.end (), -1 );
13101319
1311- this ->n_outputs = 0 ;
1312- this ->n_outputs_max = n_outputs_max;
1320+ this ->n_outputs = 0 ;
13131321
13141322 return n_outputs_max;
13151323}
@@ -1800,14 +1808,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
18001808
18011809 std::vector<int32_t > w_output_pos;
18021810
1803- GGML_ASSERT (n_outputs <= n_outputs_max);
1804-
18051811 w_output_pos.resize (n_outputs);
18061812
18071813 // build a more compact representation of the output ids
18081814 for (size_t i = 0 ; i < n_batch (); ++i) {
18091815 // map an output id to a position in the batch
1810- int32_t pos = output_ids[i];
1816+ int64_t pos = output_ids[i];
18111817 if (pos >= 0 ) {
18121818 GGML_ASSERT (pos < n_outputs);
18131819 w_output_pos[pos] = i;
@@ -2082,7 +2088,7 @@ void llama_context::opt_epoch_iter(
20822088
20832089 embd_seq.clear ();
20842090
2085- int64_t n_outputs_all = n_tokens_all;
2091+ uint32_t n_outputs_all = n_tokens_all;
20862092
20872093 auto mstate = memory->init_batch (batch, cparams.n_ubatch , embd_pooled);
20882094 if (!mstate || mstate->get_status () != LLAMA_MEMORY_STATUS_SUCCESS) {
@@ -2092,7 +2098,7 @@ void llama_context::opt_epoch_iter(
20922098
20932099 // reserve output buffer
20942100 if (output_reserve (n_outputs_all) < n_outputs_all) {
2095- LLAMA_LOG_ERROR (" %s: could not reserve space for batch with %" PRId64 " outputs\n " , __func__, n_outputs_all);
2101+ LLAMA_LOG_ERROR (" %s: could not reserve space for batch with %d outputs\n " , __func__, n_outputs_all);
20962102 GGML_ABORT (" TODO: handle this error" );
20972103 };
20982104
0 commit comments