arg : add model catalog #13385

Draft · wants to merge 1 commit into base: master
253 changes: 98 additions & 155 deletions common/arg.cpp
@@ -5,6 +5,7 @@
#include "log.h"
#include "sampling.h"
#include "chat.h"
#include "catalog.h"

// fix problem with std::min and std::max
#if defined(_WIN32)
@@ -608,15 +609,18 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
*
* Note: we use the Ollama-compatible HF API, but without using the blobId. Instead, we use the special "ggufFile" field, which returns the value for "hf_file". This keeps backward compatibility with existing cache files.
*/
static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) {
static struct common_hf_file_res common_get_hf_file(
const std::string & hf_repo_with_tag,
const std::string & bearer_token,
const std::string & model_endpoint) {
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
std::string tag = parts.size() > 1 ? parts.back() : "latest";
std::string hf_repo = parts[0];
if (string_split<std::string>(hf_repo, '/').size() != 2) {
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
}

std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
std::string url = model_endpoint + "v2/" + hf_repo + "/manifests/" + tag;

// headers
std::vector<std::string> headers;
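
The comment above describes the manifest lookup, but the JSON handling itself sits further down in the file. As a rough orientation, a minimal sketch of extracting the "ggufFile" field from an already-fetched manifest body might look like the following; the helper name, the exact field layout, and the use of nlohmann::json are assumptions for illustration, not code from this PR.

```cpp
#include <nlohmann/json.hpp>
#include <string>
#include <vector>

// Illustrative only: pull the GGUF file name out of a manifest response fetched
// from <endpoint>v2/<user>/<model>/manifests/<tag>. The "ggufFile"/"rfilename"
// layout is assumed; it maps back to the value normally passed via --hf-file,
// which keeps existing cache file names stable.
static std::string parse_manifest_gguf_file(const std::vector<char> & body) {
    const auto manifest = nlohmann::json::parse(body.begin(), body.end());
    if (manifest.contains("ggufFile") && manifest["ggufFile"].contains("rfilename")) {
        return manifest["ggufFile"]["rfilename"].get<std::string>();
    }
    return ""; // not a GGUF manifest; the caller falls back to an error
}
```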
@@ -715,7 +719,7 @@ static bool common_download_model(
return false;
}

static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &) {
static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &, const std::string &) {
LOG_ERR("error: built without CURL, cannot download model from the internet\n");
return {};
}
@@ -742,15 +746,15 @@ struct handle_model_result {
static handle_model_result common_params_handle_model(
struct common_params_model & model,
const std::string & bearer_token,
const std::string & model_path_default) {
const std::string & model_endpoint) {
handle_model_result result;
// handle pre-fill default model path and url based on hf_repo and hf_file
{
if (!model.hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
if (model.hf_file.empty()) {
if (model.path.empty()) {
auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token);
auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, model_endpoint);
if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
exit(1); // built without CURL, error message already printed
}
@@ -766,7 +770,6 @@ static handle_model_result common_params_handle_model(
}
}

std::string model_endpoint = get_model_endpoint();
model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
// make sure model path is present (for caching purposes)
if (model.path.empty()) {
@@ -784,8 +787,6 @@ static handle_model_result common_params_handle_model(
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
}

} else if (model.path.empty()) {
model.path = model_path_default;
}
}
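
To make the pre-fill above concrete, here is an illustrative trace; the repo, file, and endpoint values are examples only, and the default endpoint is assumed rather than taken from this diff.

```cpp
// Illustrative values, following the assignments shown above:
//   model.hf_repo  = "ggml-org/gte-small-Q8_0-GGUF"
//   model.hf_file  = "gte-small-q8_0.gguf"
//   model_endpoint = "https://huggingface.co/"   // assumed default; ms:// or hf-mirror:// swap it
// then:
//   model.url = "https://huggingface.co/ggml-org/gte-small-Q8_0-GGUF/resolve/main/gte-small-q8_0.gguf"
// and, for the plain-URL branch, the cache path is derived from the last
// '/'-separated segment of the URL via fs_get_cache_file().
```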

@@ -835,7 +836,6 @@ static std::string get_all_kv_cache_types() {
//

static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) {
std::string arg;
const std::string arg_prefix = "--";
common_params & params = ctx_arg.params;

@@ -875,16 +875,91 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
}
};

// normalize args
std::string input_pos_arg;
std::vector<std::string> input_opt_args;
input_opt_args.reserve(argc - 1);
for (int i = 1; i < argc; i++) {
const std::string arg_prefix = "--";

std::string arg = argv[i];
if (arg_to_options.find(arg) == arg_to_options.end()) {
// if we don't have a match, check if this can be a positional argument
if (input_pos_arg.empty()) {
input_pos_arg = std::move(arg);
continue;
} else {
// if the positional argument is already set, we cannot have another one
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
}

// normalize the argument (only applied to optional args)
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
std::replace(arg.begin(), arg.end(), '_', '-');
}
if (arg_to_options.find(arg) == arg_to_options.end()) {
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
input_opt_args.emplace_back(arg);
}

// handle positional argument (we only support one positional argument)
// the logic is as follows:
// 1. we try to find the model name in the catalog
// 2. if not found, we check the prefix protocol://
// 3. if no protocol found, we assume it is a local file
{
bool is_handled = false;
// check catalog
for (auto & entry : model_catalog) {
if (input_pos_arg == entry.name) {
is_handled = true;
// check if the model supports the current example
bool is_supported = false;
for (auto & ex : entry.examples) {
if (ctx_arg.ex == ex) {
is_supported = true;
break;
}
}
if (is_supported) {
entry.handler(params);
} else {
LOG_ERR("error: model '%s' is not supported by this tool\n", entry.name);
exit(1);
}
break;
}
}
// check protocol
// for contributors: if you want to add a new protocol,
// please make sure it supports either /resolve/main or the registry API
// see common_params_handle_model() to understand how it is handled
// note: we don't support ollama because its registry usually serves models in a proprietary format (incompatible with llama.cpp)
if (!is_handled) {
const std::string & arg = input_pos_arg;
// check if it is a URL
if (string_starts_with(arg, "http://") || string_starts_with(arg, "https://")) {
params.model.url = arg;
} else if (string_starts_with(arg, "hf://")) {
// hugging face repo
params.model.hf_repo = arg.substr(5);
} else if (string_starts_with(arg, "hf-mirror://")) {
// hugging face mirror
params.custom_model_endpoint = "hf-mirror.com";
params.model.hf_repo = arg.substr(12);
} else if (string_starts_with(arg, "ms://")) {
// modelscope
params.custom_model_endpoint = "modelscope.cn";
params.model.hf_repo = arg.substr(5);
} else {
// assume it is a local file
params.model.path = arg;
}
}
}
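
The catalog lookup above only relies on each entry exposing a name, the list of examples it supports, and a handler that pre-fills common_params. The real definition lives in catalog.h, which is not part of this diff, so the following is a hedged sketch of that shape with assumed field names.

```cpp
#include <functional>
#include <vector>

// Assumed shape of a catalog entry, inferred from the lookup loop above;
// the actual struct in catalog.h may differ.
struct common_model_catalog_entry {
    const char *                         name;      // matched against the positional argument
    std::vector<enum llama_example>      examples;  // tools this preset is meant for
    std::function<void(common_params &)> handler;   // pre-fills repo/file and tuning params
};

// e.g. a preset previously exposed as --embd-bge-small-en-default could become:
//   { "bge-small-en",
//     { LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER },
//     [](common_params & p) {
//         p.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
//         p.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
//         p.embedding     = true;
//     } }
```

With this in place, a positional argument is tried against the catalog first, then against the protocol prefixes (http(s)://, hf://, hf-mirror://, ms://), and finally treated as a local file path, exactly as the block above implements.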

// handle optional args
for (size_t i = 1; i < input_opt_args.size(); i++) {
const std::string & arg = input_opt_args[i];
auto opt = *arg_to_options[arg];
if (opt.has_value_from_env()) {
fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str());
@@ -934,7 +1009,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context

// handle model and download
{
auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH);
std::string model_endpoint = params.get_model_endpoint();
auto res = common_params_handle_model(params.model, params.hf_token, model_endpoint);
if (params.no_mmproj) {
params.mmproj = {};
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
@@ -944,12 +1020,12 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// only download mmproj if the current example is using it
for (auto & ex : mmproj_examples) {
if (ctx_arg.ex == ex) {
common_params_handle_model(params.mmproj, params.hf_token, "");
common_params_handle_model(params.mmproj, params.hf_token, model_endpoint);
break;
}
}
common_params_handle_model(params.speculative.model, params.hf_token, "");
common_params_handle_model(params.vocoder.model, params.hf_token, "");
common_params_handle_model(params.speculative.model, params.hf_token, model_endpoint);
common_params_handle_model(params.vocoder.model, params.hf_token, model_endpoint);
}
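
The endpoint now comes from params rather than the global helper, which is what lets the hf-mirror:// and ms:// prefixes take effect for every downloaded model (main, mmproj, speculative, vocoder). The member itself is not shown in this diff, so the following sketch of params.get_model_endpoint() is an assumption about its precedence, not the PR's implementation.

```cpp
#include <cstdlib>
#include <string>

// Assumed behaviour: a protocol-selected endpoint wins over the environment,
// which wins over the default Hugging Face endpoint.
std::string common_params::get_model_endpoint() const {
    if (!custom_model_endpoint.empty()) {
        return "https://" + custom_model_endpoint + "/"; // set by hf-mirror:// or ms://
    }
    if (const char * env = std::getenv("MODEL_ENDPOINT")) { // assumed env override
        std::string s = env;
        if (!s.empty() && s.back() != '/') {
            s += '/';
        }
        return s;
    }
    return "https://huggingface.co/";
}
```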

if (params.escape) {
@@ -985,6 +1061,13 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
));
}

if (params.model.path.empty()) {
throw std::invalid_argument(
"model path is empty\n"
"please specify a model file or use one from the catalog\n"
"use --catalog to see the list of available models\n");
}

return true;
}

@@ -3178,145 +3261,5 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_TTS}));

// model-specific
add_opt(common_arg(
{"--tts-oute-default"},
string_format("use default OuteTTS models (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "OuteAI/OuteTTS-0.2-500M-GGUF";
params.model.hf_file = "OuteTTS-0.2-500M-Q8_0.gguf";
params.vocoder.model.hf_repo = "ggml-org/WavTokenizer";
params.vocoder.model.hf_file = "WavTokenizer-Large-75-F16.gguf";
}
).set_examples({LLAMA_EXAMPLE_TTS}));

add_opt(common_arg(
{"--embd-bge-small-en-default"},
string_format("use default bge-small-en-v1.5 model (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
params.verbose_prompt = true;
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));

add_opt(common_arg(
{"--embd-e5-small-en-default"},
string_format("use default e5-small-v2 model (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
params.model.hf_file = "e5-small-v2-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
params.verbose_prompt = true;
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));

add_opt(common_arg(
{"--embd-gte-small-default"},
string_format("use default gte-small model (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
params.model.hf_file = "gte-small-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
params.verbose_prompt = true;
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));

add_opt(common_arg(
{"--fim-qwen-1.5b-default"},
string_format("use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf";
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));

add_opt(common_arg(
{"--fim-qwen-3b-default"},
string_format("use default Qwen 2.5 Coder 3B (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-3b-q8_0.gguf";
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));

add_opt(common_arg(
{"--fim-qwen-7b-default"},
string_format("use default Qwen 2.5 Coder 7B (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));

add_opt(common_arg(
{"--fim-qwen-7b-spec"},
string_format("use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.n_gpu_layers = 99;
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));

add_opt(common_arg(
{"--fim-qwen-14b-spec"},
string_format("use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.n_gpu_layers = 99;
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));

return ctx_arg;
}