Merged
16 changes: 10 additions & 6 deletions tools/server/server.cpp
@@ -127,7 +127,6 @@ struct slot_params {
     std::vector<std::string> response_fields;
     bool timings_per_token = false;
     bool post_sampling_probs = false;
-    bool ignore_eos = false;
 
     struct common_params_sampling sampling;
     struct common_params_speculative speculative;
@@ -441,7 +440,6 @@ struct server_task {
 
         {
             params.sampling.logit_bias.clear();
-            params.ignore_eos = json_value(data, "ignore_eos", false);
 
             const auto & logit_bias = data.find("logit_bias");
             if (logit_bias != data.end() && logit_bias->is_array()) {
@@ -472,6 +470,16 @@
                     }
                 }
             }
+
+            params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
+            if (params.sampling.ignore_eos) {
+                for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+                    if (llama_vocab_is_eog(vocab, i)) {
+                        //SRV_DBG("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(ctx, i).c_str(), -INFINITY);
+                        params.sampling.logit_bias.push_back({i, -INFINITY});
+                    }
+                }

Comment on lines +476 to +481

Member:
If this is done for every token during generation, I suspect it is going to have a significant performance impact.

Member Author:
It's done once per completion request, at the beginning, upon processing the input json parameters.

Collaborator:
If performance is a concern we could maybe provide a list of EoG tokens to iterate over instead of iterating over all tokens and checking whether each one is EoG. Although I think iterating over all tokens once per request is going to be negligible vs. iterating over all tokens once per generated token as is being done for sampling.
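
A minimal sketch of that suggestion, assuming only the vocab API already shown in the diff (`llama_vocab_n_tokens`, `llama_vocab_is_eog`); the helper name `collect_eog_tokens` and the caching point are hypothetical, not existing llama.cpp code:

```cpp
// Hypothetical helper: scan the vocab once (e.g. at model load) and cache
// the EoG token ids, so each request walks the short cached list instead
// of the full vocabulary.
#include <vector>

#include "llama.h"

static std::vector<llama_token> collect_eog_tokens(const llama_vocab * vocab) {
    std::vector<llama_token> eog;
    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
        if (llama_vocab_is_eog(vocab, i)) {
            eog.push_back(i);
        }
    }
    return eog;
}
```

Per request, the bias loop would then reduce to pushing `{t, -INFINITY}` for each cached token `t` — a handful of entries rather than a full vocab scan.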

Member:
In my system this takes about 0.3 ms for a 150k vocab model, so I suppose it is not that bad.
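
That works out to roughly 2 ns per token check. A sketch of how such a number could be reproduced, assuming the same vocab API (this is not the code behind the quoted figure):

```cpp
// Hedged micro-benchmark sketch: time one full EoG scan over the vocab.
#include <chrono>
#include <cstdio>

#include "llama.h"

static void time_eog_scan(const llama_vocab * vocab) {
    const auto t0 = std::chrono::steady_clock::now();
    int n_eog = 0;
    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
        if (llama_vocab_is_eog(vocab, i)) {
            n_eog++;
        }
    }
    const auto t1 = std::chrono::steady_clock::now();
    const double ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
    // ~150k checks in ~0.3 ms is about 2 ns per llama_vocab_is_eog call.
    printf("scanned %d tokens, %d EoG, %.3f ms\n",
           llama_vocab_n_tokens(vocab), n_eog, ms);
}
```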

Member Author:
I fixed this anyway: #14721

+            }
         }
 
         {
@@ -2217,10 +2225,6 @@ struct server_context {
             slot.params.n_predict = slot.n_predict;
         }
 
-        if (slot.params.ignore_eos && has_eos_token) {
-            slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY});
-        }
-
         {
             if (slot.smpl != nullptr) {
                 common_sampler_free(slot.smpl);