Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit c508e68

Browse files
fix: add cpu_threads to model.yaml (#1845)
Co-authored-by: vansangpfiev <[email protected]>
1 parent da7576d commit c508e68

File tree

7 files changed

+101
-74
lines changed

7 files changed

+101
-74
lines changed

engine/cli/command_line_parser.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,7 @@ void CommandLineParser::ModelUpdate(CLI::App* parent) {
908908
"ngl",
909909
"ctx_len",
910910
"n_parallel",
911+
"cpu_threads",
911912
"engine",
912913
"prompt_template",
913914
"system_template",

engine/cli/commands/model_upd_cmd.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,12 @@ void ModelUpdCmd::UpdateConfig(Json::Value& data, const std::string& key,
228228
data["n_parallel"] = static_cast<int>(f);
229229
});
230230
}},
231+
{"cpu_threads",
232+
[this](Json::Value &data, const std::string& k, const std::string& v) {
233+
UpdateNumericField(k, v, [&data](float f) {
234+
data["cpu_threads"] = static_cast<int>(f);
235+
});
236+
}},
231237
{"tp",
232238
[this](Json::Value &data, const std::string& k, const std::string& v) {
233239
UpdateNumericField(k, v, [&data](float f) {

engine/config/model_config.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ struct ModelConfig {
164164
int ngl = std::numeric_limits<int>::quiet_NaN();
165165
int ctx_len = std::numeric_limits<int>::quiet_NaN();
166166
int n_parallel = 1;
167+
int cpu_threads = -1;
167168
std::string engine;
168169
std::string prompt_template;
169170
std::string system_template;
@@ -272,6 +273,8 @@ struct ModelConfig {
272273
ctx_len = json["ctx_len"].asInt();
273274
if (json.isMember("n_parallel"))
274275
n_parallel = json["n_parallel"].asInt();
276+
if (json.isMember("cpu_threads"))
277+
cpu_threads = json["cpu_threads"].asInt();
275278
if (json.isMember("engine"))
276279
engine = json["engine"].asString();
277280
if (json.isMember("prompt_template"))
@@ -362,6 +365,9 @@ struct ModelConfig {
362365
obj["ngl"] = ngl;
363366
obj["ctx_len"] = ctx_len;
364367
obj["n_parallel"] = n_parallel;
368+
if (cpu_threads > 0) {
369+
obj["cpu_threads"] = cpu_threads;
370+
}
365371
obj["engine"] = engine;
366372
obj["prompt_template"] = prompt_template;
367373
obj["system_template"] = system_template;
@@ -474,6 +480,8 @@ struct ModelConfig {
474480
format_utils::MAGENTA);
475481
oss << format_utils::print_kv("n_parallel", std::to_string(n_parallel),
476482
format_utils::MAGENTA);
483+
oss << format_utils::print_kv("cpu_threads", std::to_string(cpu_threads),
484+
format_utils::MAGENTA);
477485
if (ngl != std::numeric_limits<int>::quiet_NaN())
478486
oss << format_utils::print_kv("ngl", std::to_string(ngl),
479487
format_utils::MAGENTA);

engine/config/yaml_config.cc

Lines changed: 65 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ void YamlHandler::ModelConfigFromYaml() {
119119
tmp.ctx_len = yaml_node_["ctx_len"].as<int>();
120120
if (yaml_node_["n_parallel"])
121121
tmp.n_parallel = yaml_node_["n_parallel"].as<int>();
122+
if (yaml_node_["cpu_threads"])
123+
tmp.cpu_threads = yaml_node_["cpu_threads"].as<int>();
122124
if (yaml_node_["tp"])
123125
tmp.tp = yaml_node_["tp"].as<int>();
124126
if (yaml_node_["stream"])
@@ -224,6 +226,8 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) {
224226
yaml_node_["ctx_len"] = model_config_.ctx_len;
225227
if (!std::isnan(static_cast<double>(model_config_.n_parallel)))
226228
yaml_node_["n_parallel"] = model_config_.n_parallel;
229+
if (!std::isnan(static_cast<double>(model_config_.cpu_threads)))
230+
yaml_node_["cpu_threads"] = model_config_.cpu_threads;
227231
if (!std::isnan(static_cast<double>(model_config_.tp)))
228232
yaml_node_["tp"] = model_config_.tp;
229233
if (!std::isnan(static_cast<double>(model_config_.stream)))
@@ -283,110 +287,112 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) {
283287
// Method to write all attributes to a YAML file
284288
void YamlHandler::WriteYamlFile(const std::string& file_path) const {
285289
try {
286-
std::ofstream outFile(file_path);
287-
if (!outFile) {
290+
std::ofstream out_file(file_path);
291+
if (!out_file) {
288292
throw std::runtime_error("Failed to open output file.");
289293
}
290294
// Write GENERAL GGUF METADATA
291-
outFile << "# BEGIN GENERAL GGUF METADATA\n";
292-
outFile << format_utils::writeKeyValue(
295+
out_file << "# BEGIN GENERAL GGUF METADATA\n";
296+
out_file << format_utils::WriteKeyValue(
293297
"id", yaml_node_["id"],
294298
"Model ID unique between models (author / quantization)");
295-
outFile << format_utils::writeKeyValue(
299+
out_file << format_utils::WriteKeyValue(
296300
"model", yaml_node_["model"],
297301
"Model ID which is used for request construct - should be "
298302
"unique between models (author / quantization)");
299-
outFile << format_utils::writeKeyValue("name", yaml_node_["name"],
303+
out_file << format_utils::WriteKeyValue("name", yaml_node_["name"],
300304
"metadata.general.name");
301305
if (yaml_node_["version"]) {
302-
outFile << "version: " << yaml_node_["version"].as<std::string>() << "\n";
306+
out_file << "version: " << yaml_node_["version"].as<std::string>() << "\n";
303307
}
304308
if (yaml_node_["files"] && yaml_node_["files"].size()) {
305-
outFile << "files: # Can be relative OR absolute local file "
309+
out_file << "files: # Can be relative OR absolute local file "
306310
"path\n";
307311
for (const auto& source : yaml_node_["files"]) {
308-
outFile << " - " << source << "\n";
312+
out_file << " - " << source << "\n";
309313
}
310314
}
311315

312-
outFile << "# END GENERAL GGUF METADATA\n";
313-
outFile << "\n";
316+
out_file << "# END GENERAL GGUF METADATA\n";
317+
out_file << "\n";
314318
// Write INFERENCE PARAMETERS
315-
outFile << "# BEGIN INFERENCE PARAMETERS\n";
316-
outFile << "# BEGIN REQUIRED\n";
319+
out_file << "# BEGIN INFERENCE PARAMETERS\n";
320+
out_file << "# BEGIN REQUIRED\n";
317321
if (yaml_node_["stop"] && yaml_node_["stop"].size()) {
318-
outFile << "stop: # tokenizer.ggml.eos_token_id\n";
322+
out_file << "stop: # tokenizer.ggml.eos_token_id\n";
319323
for (const auto& stop : yaml_node_["stop"]) {
320-
outFile << " - " << stop << "\n";
324+
out_file << " - " << stop << "\n";
321325
}
322326
}
323327

324-
outFile << "# END REQUIRED\n";
325-
outFile << "\n";
326-
outFile << "# BEGIN OPTIONAL\n";
327-
outFile << format_utils::writeKeyValue("size", yaml_node_["size"]);
328-
outFile << format_utils::writeKeyValue("stream", yaml_node_["stream"],
328+
out_file << "# END REQUIRED\n";
329+
out_file << "\n";
330+
out_file << "# BEGIN OPTIONAL\n";
331+
out_file << format_utils::WriteKeyValue("size", yaml_node_["size"]);
332+
out_file << format_utils::WriteKeyValue("stream", yaml_node_["stream"],
329333
"Default true?");
330-
outFile << format_utils::writeKeyValue("top_p", yaml_node_["top_p"],
334+
out_file << format_utils::WriteKeyValue("top_p", yaml_node_["top_p"],
331335
"Ranges: 0 to 1");
332-
outFile << format_utils::writeKeyValue(
336+
out_file << format_utils::WriteKeyValue(
333337
"temperature", yaml_node_["temperature"], "Ranges: 0 to 1");
334-
outFile << format_utils::writeKeyValue(
338+
out_file << format_utils::WriteKeyValue(
335339
"frequency_penalty", yaml_node_["frequency_penalty"], "Ranges: 0 to 1");
336-
outFile << format_utils::writeKeyValue(
340+
out_file << format_utils::WriteKeyValue(
337341
"presence_penalty", yaml_node_["presence_penalty"], "Ranges: 0 to 1");
338-
outFile << format_utils::writeKeyValue(
342+
out_file << format_utils::WriteKeyValue(
339343
"max_tokens", yaml_node_["max_tokens"],
340344
"Should be default to context length");
341-
outFile << format_utils::writeKeyValue("seed", yaml_node_["seed"]);
342-
outFile << format_utils::writeKeyValue("dynatemp_range",
345+
out_file << format_utils::WriteKeyValue("seed", yaml_node_["seed"]);
346+
out_file << format_utils::WriteKeyValue("dynatemp_range",
343347
yaml_node_["dynatemp_range"]);
344-
outFile << format_utils::writeKeyValue("dynatemp_exponent",
348+
out_file << format_utils::WriteKeyValue("dynatemp_exponent",
345349
yaml_node_["dynatemp_exponent"]);
346-
outFile << format_utils::writeKeyValue("top_k", yaml_node_["top_k"]);
347-
outFile << format_utils::writeKeyValue("min_p", yaml_node_["min_p"]);
348-
outFile << format_utils::writeKeyValue("tfs_z", yaml_node_["tfs_z"]);
349-
outFile << format_utils::writeKeyValue("typ_p", yaml_node_["typ_p"]);
350-
outFile << format_utils::writeKeyValue("repeat_last_n",
350+
out_file << format_utils::WriteKeyValue("top_k", yaml_node_["top_k"]);
351+
out_file << format_utils::WriteKeyValue("min_p", yaml_node_["min_p"]);
352+
out_file << format_utils::WriteKeyValue("tfs_z", yaml_node_["tfs_z"]);
353+
out_file << format_utils::WriteKeyValue("typ_p", yaml_node_["typ_p"]);
354+
out_file << format_utils::WriteKeyValue("repeat_last_n",
351355
yaml_node_["repeat_last_n"]);
352-
outFile << format_utils::writeKeyValue("repeat_penalty",
356+
out_file << format_utils::WriteKeyValue("repeat_penalty",
353357
yaml_node_["repeat_penalty"]);
354-
outFile << format_utils::writeKeyValue("mirostat", yaml_node_["mirostat"]);
355-
outFile << format_utils::writeKeyValue("mirostat_tau",
358+
out_file << format_utils::WriteKeyValue("mirostat", yaml_node_["mirostat"]);
359+
out_file << format_utils::WriteKeyValue("mirostat_tau",
356360
yaml_node_["mirostat_tau"]);
357-
outFile << format_utils::writeKeyValue("mirostat_eta",
361+
out_file << format_utils::WriteKeyValue("mirostat_eta",
358362
yaml_node_["mirostat_eta"]);
359-
outFile << format_utils::writeKeyValue("penalize_nl",
363+
out_file << format_utils::WriteKeyValue("penalize_nl",
360364
yaml_node_["penalize_nl"]);
361-
outFile << format_utils::writeKeyValue("ignore_eos",
365+
out_file << format_utils::WriteKeyValue("ignore_eos",
362366
yaml_node_["ignore_eos"]);
363-
outFile << format_utils::writeKeyValue("n_probs", yaml_node_["n_probs"]);
364-
outFile << format_utils::writeKeyValue("min_keep", yaml_node_["min_keep"]);
365-
outFile << format_utils::writeKeyValue("grammar", yaml_node_["grammar"]);
366-
outFile << "# END OPTIONAL\n";
367-
outFile << "# END INFERENCE PARAMETERS\n";
368-
outFile << "\n";
367+
out_file << format_utils::WriteKeyValue("n_probs", yaml_node_["n_probs"]);
368+
out_file << format_utils::WriteKeyValue("min_keep", yaml_node_["min_keep"]);
369+
out_file << format_utils::WriteKeyValue("grammar", yaml_node_["grammar"]);
370+
out_file << "# END OPTIONAL\n";
371+
out_file << "# END INFERENCE PARAMETERS\n";
372+
out_file << "\n";
369373
// Write MODEL LOAD PARAMETERS
370-
outFile << "# BEGIN MODEL LOAD PARAMETERS\n";
371-
outFile << "# BEGIN REQUIRED\n";
372-
outFile << format_utils::writeKeyValue("engine", yaml_node_["engine"],
374+
out_file << "# BEGIN MODEL LOAD PARAMETERS\n";
375+
out_file << "# BEGIN REQUIRED\n";
376+
out_file << format_utils::WriteKeyValue("engine", yaml_node_["engine"],
373377
"engine to run model");
374-
outFile << "prompt_template:";
375-
outFile << " " << yaml_node_["prompt_template"] << "\n";
376-
outFile << "# END REQUIRED\n";
377-
outFile << "\n";
378-
outFile << "# BEGIN OPTIONAL\n";
379-
outFile << format_utils::writeKeyValue(
378+
out_file << "prompt_template:";
379+
out_file << " " << yaml_node_["prompt_template"] << "\n";
380+
out_file << "# END REQUIRED\n";
381+
out_file << "\n";
382+
out_file << "# BEGIN OPTIONAL\n";
383+
out_file << format_utils::WriteKeyValue(
380384
"ctx_len", yaml_node_["ctx_len"],
381385
"llama.context_length | 0 or undefined = loaded from model");
382-
outFile << format_utils::writeKeyValue("n_parallel",
386+
out_file << format_utils::WriteKeyValue("n_parallel",
383387
yaml_node_["n_parallel"]);
384-
outFile << format_utils::writeKeyValue("ngl", yaml_node_["ngl"],
388+
out_file << format_utils::WriteKeyValue("cpu_threads",
389+
yaml_node_["cpu_threads"]);
390+
out_file << format_utils::WriteKeyValue("ngl", yaml_node_["ngl"],
385391
"Undefined = loaded from model");
386-
outFile << "# END OPTIONAL\n";
387-
outFile << "# END MODEL LOAD PARAMETERS\n";
392+
out_file << "# END OPTIONAL\n";
393+
out_file << "# END MODEL LOAD PARAMETERS\n";
388394

389-
outFile.close();
395+
out_file.close();
390396
} catch (const std::exception& e) {
391397
std::cerr << "Error writing to file: " << e.what() << std::endl;
392398
throw;

engine/test/components/test_format_utils.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,37 +9,37 @@ TEST_F(FormatUtilsTest, WriteKeyValue) {
99
{
1010
YAML::Node node;
1111
std::string result =
12-
format_utils::writeKeyValue("key", node["does_not_exist"]);
12+
format_utils::WriteKeyValue("key", node["does_not_exist"]);
1313
EXPECT_EQ(result, "");
1414
}
1515

1616
{
1717
YAML::Node node = YAML::Load("value");
18-
std::string result = format_utils::writeKeyValue("key", node);
18+
std::string result = format_utils::WriteKeyValue("key", node);
1919
EXPECT_EQ(result, "key: value\n");
2020
}
2121

2222
{
2323
YAML::Node node = YAML::Load("3.14159");
24-
std::string result = format_utils::writeKeyValue("key", node);
24+
std::string result = format_utils::WriteKeyValue("key", node);
2525
EXPECT_EQ(result, "key: 3.14159\n");
2626
}
2727

2828
{
2929
YAML::Node node = YAML::Load("3.000000");
30-
std::string result = format_utils::writeKeyValue("key", node);
30+
std::string result = format_utils::WriteKeyValue("key", node);
3131
EXPECT_EQ(result, "key: 3\n");
3232
}
3333

3434
{
3535
YAML::Node node = YAML::Load("3.140000");
36-
std::string result = format_utils::writeKeyValue("key", node);
36+
std::string result = format_utils::WriteKeyValue("key", node);
3737
EXPECT_EQ(result, "key: 3.14\n");
3838
}
3939

4040
{
4141
YAML::Node node = YAML::Load("value");
42-
std::string result = format_utils::writeKeyValue("key", node, "comment");
42+
std::string result = format_utils::WriteKeyValue("key", node, "comment");
4343
EXPECT_EQ(result, "key: value # comment\n");
4444
}
4545
}

engine/test/components/test_yaml_handler.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ temperature: 0.7
6363
max_tokens: 100
6464
stream: true
6565
n_parallel: 2
66+
cpu_threads: 3
6667
stop:
6768
- "END"
6869
files:
@@ -84,6 +85,7 @@ n_parallel: 2
8485
EXPECT_EQ(config.max_tokens, 100);
8586
EXPECT_TRUE(config.stream);
8687
EXPECT_EQ(config.n_parallel, 2);
88+
EXPECT_EQ(config.cpu_threads, 3);
8789
EXPECT_EQ(config.stop.size(), 1);
8890
EXPECT_EQ(config.stop[0], "END");
8991
EXPECT_EQ(config.files.size(), 1);
@@ -104,6 +106,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) {
104106
new_config.max_tokens = 200;
105107
new_config.stream = false;
106108
new_config.n_parallel = 2;
109+
new_config.cpu_threads = 3;
107110
new_config.stop = {"STOP", "END"};
108111
new_config.files = {"updated_file1.gguf", "updated_file2.gguf"};
109112

@@ -120,6 +123,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) {
120123
EXPECT_EQ(config.max_tokens, 200);
121124
EXPECT_FALSE(config.stream);
122125
EXPECT_EQ(config.n_parallel, 2);
126+
EXPECT_EQ(config.cpu_threads, 3);
123127
EXPECT_EQ(config.stop.size(), 2);
124128
EXPECT_EQ(config.stop[0], "STOP");
125129
EXPECT_EQ(config.stop[1], "END");
@@ -140,6 +144,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) {
140144
new_config.max_tokens = 150;
141145
new_config.stream = true;
142146
new_config.n_parallel = 2;
147+
new_config.cpu_threads = 3;
143148
new_config.stop = {"HALT"};
144149
new_config.files = {"write_test_file.gguf"};
145150

@@ -164,6 +169,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) {
164169
EXPECT_EQ(read_config.max_tokens, 150);
165170
EXPECT_TRUE(read_config.stream);
166171
EXPECT_EQ(read_config.n_parallel, 2);
172+
EXPECT_EQ(read_config.cpu_threads, 3);
167173
EXPECT_EQ(read_config.stop.size(), 1);
168174
EXPECT_EQ(read_config.stop[0], "HALT");
169175
EXPECT_EQ(read_config.files.size(), 1);

engine/utils/format_utils.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@ inline std::string print_float(const std::string& key, float value) {
4646
} else
4747
return "";
4848
};
49-
inline std::string writeKeyValue(const std::string& key,
49+
inline std::string WriteKeyValue(const std::string& key,
5050
const YAML::Node& value,
5151
const std::string& comment = "") {
52-
std::ostringstream outFile;
52+
std::ostringstream out_file;
5353
if (!value)
5454
return "";
55-
outFile << key << ": ";
55+
out_file << key << ": ";
5656

5757
// Check if the value is a float and round it to 6 decimal places
5858
if (value.IsScalar()) {
@@ -66,19 +66,19 @@ inline std::string writeKeyValue(const std::string& key,
6666
if (strValue.back() == '.') {
6767
strValue.pop_back();
6868
}
69-
outFile << strValue;
69+
out_file << strValue;
7070
} catch (const std::exception& e) {
71-
outFile << value; // If not a float, write as is
71+
out_file << value; // If not a float, write as is
7272
}
7373
} else {
74-
outFile << value;
74+
out_file << value;
7575
}
7676

7777
if (!comment.empty()) {
78-
outFile << " # " << comment;
78+
out_file << " # " << comment;
7979
}
80-
outFile << "\n";
81-
return outFile.str();
80+
out_file << "\n";
81+
return out_file.str();
8282
};
8383

8484
inline std::string BytesToHumanReadable(uint64_t bytes) {

0 commit comments

Comments
 (0)