@@ -999,6 +999,8 @@ struct winogrande_entry {
999999 size_t i_logits;
10001000 size_t common_prefix;
10011001 size_t required_tokens;
1002+ size_t n_base1; // number of tokens for context + choice 1
1003+ size_t n_base2; // number of tokens for context + choice 2
10021004 std::vector<llama_token> seq_tokens[2 ];
10031005};
10041006
@@ -1038,38 +1040,6 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string&
10381040 auto choice2 = line.substr (comma_pos[2 ]+1 , comma_pos[3 ] - comma_pos[2 ] - 1 );
10391041 auto answer = line.substr (comma_pos[3 ]+1 , line.size () - comma_pos[3 ] - 1 );
10401042 auto index = line.substr (0 , comma_pos[0 ]);
1041- if ('a' <= sentence[0] && sentence[0] <= 'z') {
1042- // make the first letter a capital letter
1043- sentence[0] -= 'a' - 'A';
1044- }
1045- for (int i = 0 ; i < (int ) sentence.size () - 1 ; ++i) {
1046- // trim repeated spaces and spaces before punctuation
1047- if (sentence[i] == ' ' ) {
1048- char next = sentence[i+1 ];
1049- if (next == ' ' || next == ',' || next == '.' || next == '\'') {
1050- char r[2 ] = { next, 0 };
1051- sentence.replace (i, 2 , r);
1052- --i; // stay at the same index for repeated spaces
1053- }
1054- } else if (sentence[i] == ',' || sentence[i] == '.') {
1055- if (sentence[i] == sentence[i+1 ]) {
1056- // trim repeated punctuation (forward to work at the end of sentences)
1057- char r[2 ] = { sentence[i], 0 };
1058- sentence.replace (i, 2 , r);
1059- --i; // same index to then run the other checks on that punctuation
1060- } else if (0 < i && sentence[i-1 ] == sentence[i]) {
1061- // trim repeated punctuation (looks back to work with the space trim)
1062- char r[2 ] = { sentence[i], 0 };
1063- sentence.replace (i-1 , 2 , r);
1064- i -= 2 ; // go back because content was shifted
1065- } else if (sentence[i+1 ] != ' ' ) {
1066- // add missing space after punctuation
1067- // (since the loop stops before the end, this adds no trailing space)
1068- char r[3 ] = { sentence[i], ' ' , 0 };
1069- sentence.replace (i, 1 , r);
1070- }
1071- }
1072- }
10731043 int where = 0 ;
10741044 for ( ; where < int (sentence.size ()); ++where) {
10751045 if (sentence[where] == '_') break;
@@ -1106,6 +1076,8 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string&
11061076 */
11071077static void winogrande_score (llama_context * ctx, const gpt_params & params) {
11081078
1079+ constexpr int k_min_trailing_ctx = 3 ;
1080+
11091081 auto data = load_winogrande_from_csv (params.prompt );
11101082 if (data.empty ()) {
11111083 fprintf (stderr, " %s: no tasks\n " , __func__);
@@ -1150,11 +1122,13 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
11501122 task.common_prefix ++;
11511123 }
11521124
1125+ // TODO: the last token of each of the sequences don't need to be evaluated
11531126 task.required_tokens = task.common_prefix +
11541127 task.seq_tokens [0 ].size () - task.common_prefix +
1155- task.seq_tokens [1 ].size () - task.common_prefix
1156- // the last tokens don't need to be evaluated
1157- - 2 ;
1128+ task.seq_tokens [1 ].size () - task.common_prefix ;
1129+
1130+ task.n_base1 = ::llama_tokenize (ctx, task.first + task.choices [0 ], add_bos).size ();
1131+ task.n_base2 = ::llama_tokenize (ctx, task.first + task.choices [1 ], add_bos).size ();
11581132 }
11591133
11601134 fprintf (stderr, " %s : calculating winogrande score over selected tasks.\n " , __func__);
@@ -1201,8 +1175,8 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
12011175 n_logits += 1 ;
12021176
12031177 for (int s = 0 ; s < 2 ; ++s) {
1204- // end before the last token, no need to predict past the end of the sequences
1205- for (size_t i = data[i1].common_prefix ; i < data[i1].seq_tokens [s].size () - 1 ; ++i) {
1178+ // TODO: end before the last token, no need to predict past the end of the sequences
1179+ for (size_t i = data[i1].common_prefix ; i < data[i1].seq_tokens [s].size (); ++i) {
12061180 llama_batch_add (batch, data[i1].seq_tokens [s][i], i, { s0 + s }, true );
12071181 n_logits += 1 ;
12081182 }
@@ -1234,38 +1208,49 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
12341208 for (size_t i = i0; i < i1; ++i) {
12351209 auto & task = data[i];
12361210
1237- // start from the end of the common prefix
1238- size_t li = 0 ;
1239- for (size_t j = task.common_prefix -1 ; j < task.seq_tokens [0 ].size ()-1 ; ++j) {
1211+ const bool skip_choice =
1212+ task.seq_tokens [0 ].size () - task.common_prefix > k_min_trailing_ctx &&
1213+ task.seq_tokens [1 ].size () - task.common_prefix > k_min_trailing_ctx;
1214+
1215+ const auto & n_base1 = skip_choice ? task.n_base1 : task.common_prefix ;
1216+ const int last_1st = task.seq_tokens [0 ].size () - n_base1 > 1 ? 1 : 0 ;
1217+ size_t li = n_base1 - task.common_prefix ;
1218+ for (size_t j = n_base1-1 ; j < task.seq_tokens [0 ].size ()-1 -last_1st; ++j) {
12401219 eval_pairs.emplace_back (task.i_logits + li++, task.seq_tokens [0 ][j+1 ]);
12411220 }
1242- // first token of the second choice is predicted by the end of the common prefix
1243- eval_pairs.emplace_back (task.i_logits , task.seq_tokens [1 ][task.common_prefix ]);
1244- for (size_t j = task.common_prefix ; j < task.seq_tokens [1 ].size ()-1 ; ++j) {
1221+ const auto & n_base2 = skip_choice ? task.n_base2 : task.common_prefix ;
1222+ const int last_2nd = task.seq_tokens [1 ].size () - n_base2 > 1 ? 1 : 0 ;
1223+ // FIXME: this uses the wrong first logits when not skipping the choice word
1224+ li = task.seq_tokens [0 ].size () - task.common_prefix + n_base2 - task.common_prefix ;
1225+ for (size_t j = n_base2-1 ; j < task.seq_tokens [1 ].size ()-1 -last_2nd; ++j) {
12451226 eval_pairs.emplace_back (task.i_logits + li++, task.seq_tokens [1 ][j+1 ]);
12461227 }
1247- if (i < i1 - 1 ) {
1248- // make sure all logits have been processed as expected
1249- GGML_ASSERT (task.i_logits + li == data[i+1 ].i_logits );
1250- }
12511228 }
12521229 compute_logprobs (batch_logits.data (), n_vocab, workers, eval_pairs, eval_results);
12531230
12541231 size_t ir = 0 ;
12551232 for (size_t i = i0; i < i1; ++i) {
12561233 auto & task = data[i];
12571234
1235+ const bool skip_choice =
1236+ task.seq_tokens [0 ].size () - task.common_prefix > k_min_trailing_ctx &&
1237+ task.seq_tokens [1 ].size () - task.common_prefix > k_min_trailing_ctx;
1238+
12581239 float score_1st = 0 ;
1259- for (size_t j = task.common_prefix -1 ; j < task.seq_tokens [0 ].size ()-1 ; ++j) {
1240+ const auto & n_base1 = skip_choice ? task.n_base1 : task.common_prefix ;
1241+ const int last_1st = task.seq_tokens [0 ].size () - n_base1 > 1 ? 1 : 0 ;
1242+ for (size_t j = n_base1-1 ; j < task.seq_tokens [0 ].size ()-1 -last_1st; ++j) {
12601243 score_1st += eval_results[ir++];
12611244 }
1262- score_1st /= (task.seq_tokens [0 ].size () - task. common_prefix );
1245+ score_1st /= (task.seq_tokens [0 ].size () - n_base1 - last_1st );
12631246
12641247 float score_2nd = 0 ;
1265- for (size_t j = task.common_prefix -1 ; j < task.seq_tokens [1 ].size ()-1 ; ++j) {
1248+ const auto & n_base2 = skip_choice ? task.n_base2 : task.common_prefix ;
1249+ const int last_2nd = task.seq_tokens [1 ].size () - n_base2 > 1 ? 1 : 0 ;
1250+ for (size_t j = n_base2-1 ; j < task.seq_tokens [1 ].size ()-1 -last_2nd; ++j) {
12661251 score_2nd += eval_results[ir++];
12671252 }
1268- score_2nd /= (task.seq_tokens [1 ].size () - task. common_prefix );
1253+ score_2nd /= (task.seq_tokens [1 ].size () - n_base2 - last_2nd );
12691254
12701255 int result = score_1st > score_2nd ? 1 : 2 ;
12711256
0 commit comments