diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 8cb8d0033f7d9..1b4961d425122 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1321,17 +1321,29 @@ struct server_slot {
             && are_lora_equal(lora, other_slot.lora);
     }
 
+    // There are two caps on the budget of a single request:
+    // * [params.n_predict]
+    // * [global_params.n_predict]
+    // For either cap, -1 means "no limit" and -2 means "limited by the context window".
+    // This function returns true while the request still has budget under both caps.
     bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
         }
+        n_remaining = INT32_MAX;
 
-        n_remaining = -1;
+        // Apply each explicit (non-negative) cap independently; the tightest one wins.
+        // Negative values are sentinels (-1, -2) and must not be used as token counts.
+        if (params.n_predict >= 0) {
+            n_remaining = std::min(n_remaining, params.n_predict - n_decoded);
+        }
+        if (global_params.n_predict >= 0) {
+            n_remaining = std::min(n_remaining, global_params.n_predict - n_decoded);
+        }
 
-        if (params.n_predict != -1) {
-            n_remaining = params.n_predict - n_decoded;
-        } else if (global_params.n_predict != -1) {
-            n_remaining = global_params.n_predict - n_decoded;
+        // The request or server have limits based on the context window.
+        if (params.n_predict == -2 || global_params.n_predict == -2) {
+            n_remaining = std::min(n_remaining, n_ctx - n_decoded);
         }
 
         return n_remaining > 0; // no budget
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index 0ed5b99bef4e4..532c08597c7ab 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -143,6 +143,7 @@ def test_consistent_result_same_seed(n_slots: int):
 def test_different_result_different_seed(n_slots: int):
     global server
     server.n_slots = n_slots
+    server.n_predict = -1
     server.start()
     last_res = None
     for seed in range(4):
@@ -150,6 +151,7 @@ def test_different_result_different_seed(n_slots: int):
             "prompt": "I believe the meaning of life is",
             "seed": seed,
             "temperature": 1.0,
+            "n_predict": -1,
             "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
         })
         if last_res is not None:
@@ -426,3 +428,18 @@ def test_cancel_request():
     time.sleep(1) # wait for HTTP_POLLING_SECONDS
     res = server.make_request("GET", "/slots")
     assert res.body[0]["is_processing"] == False
+
+
+def test_context_window_sized_completion():
+    server = ServerPreset.tinyllama2()
+    server.n_ctx = 16
+    server.n_predict = -1
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": -2,
+        "prompt": "The 50 states in the US are ",
+    })
+    assert res.status_code == 200
+    assert res.body["timings"]["predicted_n"] == server.n_ctx
+    assert res.body["stop_type"] == "limit"
+    assert type(res.body["has_new_line"]) == bool