Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 72 additions & 13 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,22 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
params.model = argv[i];
return true;
}
if (arg == "-md" || arg == "--model-draft") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.model_draft = argv[i];
return true;
}
if (arg == "-a" || arg == "--alias") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.model_alias = argv[i];
return true;
}
if (arg == "-mu" || arg == "--model-url") {
if (++i >= argc) {
invalid_param = true;
Expand All @@ -655,20 +671,20 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
params.model_url = argv[i];
return true;
}
if (arg == "-md" || arg == "--model-draft") {
if (arg == "-hfr" || arg == "--hf-repo") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.model_draft = argv[i];
params.hf_repo = argv[i];
return true;
}
if (arg == "-a" || arg == "--alias") {
if (arg == "-hff" || arg == "--hf-file") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.model_alias = argv[i];
params.hf_file = argv[i];
return true;
}
if (arg == "--lora") {
Expand Down Expand Up @@ -1403,10 +1419,14 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" layer range to apply the control vector(s) to, start and end inclusive\n");
printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str());
printf(" -mu MODEL_URL, --model-url MODEL_URL\n");
printf(" model download url (default: %s)\n", params.model_url.c_str());
printf(" -md FNAME, --model-draft FNAME\n");
printf(" draft model for speculative decoding\n");
printf(" draft model for speculative decoding (default: unused)\n");
printf(" -mu MODEL_URL, --model-url MODEL_URL\n");
printf(" model download url (default: unused)\n");
printf(" -hfr REPO, --hf-repo REPO\n");
printf(" Hugging Face model repository (default: unused)\n");
printf(" -hff FILE, --hf-file FILE\n");
printf(" Hugging Face model file (default: unused)\n");
printf(" -ld LOGDIR, --logdir LOGDIR\n");
printf(" path under which to save YAML logs (no logging if unset)\n");
printf(" --override-kv KEY=TYPE:VALUE\n");
Expand Down Expand Up @@ -1655,8 +1675,10 @@ void llama_batch_add(

#ifdef LLAMA_USE_CURL

struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
struct llama_model_params params) {
struct llama_model * llama_load_model_from_url(
const char * model_url,
const char * path_model,
const struct llama_model_params & params) {
// Basic validation of the model_url
if (!model_url || strlen(model_url) == 0) {
fprintf(stderr, "%s: invalid model_url\n", __func__);
Expand Down Expand Up @@ -1850,25 +1872,62 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
return llama_load_model_from_file(path_model, params);
}

struct llama_model * llama_load_model_from_hf(
const char * repo,
const char * model,
const char * path_model,
const struct llama_model_params & params) {
// construct hugging face model url:
//
// --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
// https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
//
// --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
// https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
//

std::string model_url = "https://huggingface.co/";
model_url += repo;
model_url += "/resolve/main/";
model_url += model;

return llama_load_model_from_url(model_url.c_str(), path_model, params);
}

#else

struct llama_model * llama_load_model_from_url(const char * /*model_url*/, const char * /*path_model*/,
struct llama_model_params /*params*/) {
struct llama_model * llama_load_model_from_url(
const char * /*model_url*/,
const char * /*path_model*/,
const struct llama_model_params & /*params*/) {
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
return nullptr;
}

struct llama_model * llama_load_model_from_hf(
const char * /*repo*/,
const char * /*model*/,
const char * /*path_model*/,
const struct llama_model_params & /*params*/) {
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return nullptr;
}

#endif // LLAMA_USE_CURL

std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto mparams = llama_model_params_from_gpt_params(params);

llama_model * model = nullptr;
if (!params.model_url.empty()) {

if (!params.hf_repo.empty() && !params.hf_file.empty()) {
model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
} else if (!params.model_url.empty()) {
model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
} else {
model = llama_load_model_from_file(params.model.c_str(), mparams);
}

if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return std::make_tuple(nullptr, nullptr);
Expand Down Expand Up @@ -1908,7 +1967,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
}

for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
float lora_scale = std::get<1>(params.lora_adapter[i]);
int err = llama_model_apply_lora_from_file(model,
lora_adapter.c_str(),
Expand Down
10 changes: 6 additions & 4 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ struct gpt_params {
struct llama_sampling_params sparams;

std::string model = "models/7B/ggml-model-f16.gguf"; // model path
std::string model_url = ""; // model url to download
std::string model_draft = ""; // draft model for speculative decoding
std::string model_draft = ""; // draft model for speculative decoding
std::string model_alias = "unknown"; // model alias
std::string model_url = ""; // model url to download
std::string hf_repo = ""; // HF repo
std::string hf_file = ""; // HF file
std::string prompt = "";
std::string prompt_file = ""; // store the external prompt file name
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
Expand Down Expand Up @@ -192,8 +194,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);

struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
struct llama_model_params params);
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);

// Batch utils

Expand Down