Skip to content

arg : add model catalog #13385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft

Conversation

ngxson
Copy link
Collaborator

@ngxson ngxson commented May 8, 2025

I'm drafting this PR to illustrate my idea from #10932 (comment)

This is be viewed as an improvement to the existing "preset" system. The main ideas are:

  1. Make a dedicated catalog.h file with proper guides for contributors
  2. Use the catalog model name as positional arg (on top of that, we also support protocols like hf://, hf-mirror:// and ms://)

For example:

llama-server fim-qwen-7b-spec
llama-server hf://ggml-org/gemma-3-4b-it-GGUF
llama-server ./models/local-file.gguf

# override params also work better now (both commands below work)
llama-server fim-qwen-7b-spec --temperature 0
llama-server --temperature 0 fim-qwen-7b-spec

WDYT @ggerganov ?

@slaren
Copy link
Member

slaren commented May 8, 2025

This should probably be in a configuration file that can be modified easily.

@ngxson
Copy link
Collaborator Author

ngxson commented May 8, 2025

This should probably be in a configuration file that can be modified easily.

I thought about this, but the main problems are:

  • There will be no static type checking (or we will need to develop that system)
  • We will need to distribute the config file separately, or we will need to find a way to "embed" it into the binary

@slaren
Copy link
Member

slaren commented May 8, 2025

Distributing the file would be good, because it would allow users to add their own presets. It would definitely be more work to implement this, but maybe it could be done without too much trouble with the json library?

@ngxson
Copy link
Collaborator Author

ngxson commented May 8, 2025

Yes I think it can be done without json, we can implement a subset of YAML or .ini spec which will be some simple string/regex matching. But the main concern is that we need to have a validator that replaces the static type check (which will be quite heavy, no matter which input format) - Do you have any idea how to do this in a simple way?

@slaren
Copy link
Member

slaren commented May 8, 2025

I was thinking of using the json library that we already have to load the data directly into a struct. It seems to have support for doing this in a fairly simple way with what they call arbitrary type conversion. I expect that it will be able to detect when the json schema doesn't match the struct.

@ngxson
Copy link
Collaborator Author

ngxson commented May 8, 2025

Ok I see. Indeed, I try not to depend too much on nlohmann::json because it can be quite complicated for debugging, also it may make it harder for us to switch to another (more performant) JSON framework in the future.

But I think we can base on the same idea of arbitrary type conversion as you said. Indeed, I make my own deserialize / serializer using macro in wllama based on the same principle. If we target to support only string, int and float (no null value), it can already be enough. Though, one important thing that I don't know is how to take advantage of cpp template to do that.

@slaren
Copy link
Member

slaren commented May 8, 2025

I was hoping to keep it very simple, writing a serializer/deserializer could be a lot of code. This works:

NLOHMANN_JSON_SERIALIZE_ENUM(enum llama_pooling_type, {
    {LLAMA_POOLING_TYPE_MEAN, "mean"},
    {LLAMA_POOLING_TYPE_CLS, "cls"},
    {LLAMA_POOLING_TYPE_LAST, "last"},
})

NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(common_params_model,
    path,
    url,
    hf_repo,
    hf_file)

NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(common_params,
    model,
    port,
    n_gpu_layers,
    n_ctx,
    pooling_type)

int main() {
    json j = R"(
        {
            "model": {
                "hf_file": "qwen2.5-coder-7b-q8_0.gguf",
                "hf_repo": "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"
            },
            "n_gpu_layers": 99,
            "port": 8091,
            "n_ctx": 2048,
            "pooling_type": "mean"
        }
        )"_json;

    // conversion: json -> common_params
    auto params = j.template get<common_params>();
}

Unfortunately this doesn't validate much and will ignore extra parameters or invalid enum values. It can be addressed with a bit of boilerplate.

#include "json.hpp"
#include "common.h"

using namespace nlohmann;

#define NLOHMANN_JSON_SERIALIZE_ENUM_STRICT(ENUM_TYPE, ...)                                            \
    template<typename BasicJsonType>                                                            \
    inline void to_json(BasicJsonType& j, const ENUM_TYPE& e)                                   \
    {                                                                                           \
        static_assert(std::is_enum<ENUM_TYPE>::value, #ENUM_TYPE " must be an enum!");          \
        static const std::pair<ENUM_TYPE, BasicJsonType> m[] = __VA_ARGS__;                     \
        auto it = std::find_if(std::begin(m), std::end(m),                                      \
                               [e](const std::pair<ENUM_TYPE, BasicJsonType>& ej_pair) -> bool  \
        {                                                                                       \
            return ej_pair.first == e;                                                          \
        });                                                                                     \
        j = ((it != std::end(m)) ? it : std::begin(m))->second;                                 \
    }                                                                                           \
    template<typename BasicJsonType>                                                            \
    inline void from_json(const BasicJsonType& j, ENUM_TYPE& e)                                 \
    {                                                                                           \
        static_assert(std::is_enum<ENUM_TYPE>::value, #ENUM_TYPE " must be an enum!");          \
        static const std::pair<ENUM_TYPE, BasicJsonType> m[] = __VA_ARGS__;                     \
        auto it = std::find_if(std::begin(m), std::end(m),                                      \
                               [&j](const std::pair<ENUM_TYPE, BasicJsonType>& ej_pair) -> bool \
        {                                                                                       \
            return ej_pair.second == j;                                                         \
        });                                                                                     \
        if (it != std::end(m)) {                                                                \
            e = it->first;                                                                      \
        } else {                                                                                \
            throw std::invalid_argument("Invalid enum value: " + j.dump());                     \
        }                                                                                       \
    }


NLOHMANN_JSON_SERIALIZE_ENUM_STRICT(enum llama_pooling_type, {
    {LLAMA_POOLING_TYPE_MEAN, "mean"},
    {LLAMA_POOLING_TYPE_CLS, "cls"},
    {LLAMA_POOLING_TYPE_LAST, "last"},
})

NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(common_params_model,
    path,
    url,
    hf_repo,
    hf_file)

NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(common_params,
    model,
    port,
    n_gpu_layers,
    n_ctx,
    pooling_type)


void check_for_unknown_fields_impl(const json& j, const json& expected) {
    for (auto& [key, value] : j.items()) {
        if (!expected.contains(key)) {
            throw std::runtime_error("Unknown field in JSON: " + key);
        }

        // Recursively check nested objects
        if (value.is_object() && expected[key].is_object()) {
            check_for_unknown_fields_impl(value, expected[key]);
        }
    }
}

template<typename T>
void check_for_unknown_fields(const json& j) {
    json expected = T{};
    check_for_unknown_fields_impl(j, expected);
}

int main() {
    common_params params;

    json j = R"(
        {
            "model": {
                "hf_file": "qwen2.5-coder-7b-q8_0.gguf",
                "hf_repo": "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"
            },
            "n_gpu_layers": 99,
            "port": 8091,
            "n_ctx": 2048,
            "pooling_type": "cls"
        }
        )"_json;

    // throws if there are unknown fields in the JSON
    check_for_unknown_fields<common_params>(j);

    // conversion: json -> common_params
    auto p2 = j.template get<common_params>();
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants