diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index d64956b843851..68a5c6ad5306d 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -21,6 +21,7 @@ else() add_subdirectory(main) add_subdirectory(perplexity) add_subdirectory(quantize) + add_subdirectory(server-core) if (LLAMA_BUILD_SERVER) add_subdirectory(server) endif() diff --git a/tools/server-core/CMakeLists.txt b/tools/server-core/CMakeLists.txt new file mode 100644 index 0000000000000..8a2209614971e --- /dev/null +++ b/tools/server-core/CMakeLists.txt @@ -0,0 +1,37 @@ +# serve + +set(TARGET server-core) + +add_library(${TARGET} + server-core.cpp + server-utils.hpp +) + +target_include_directories(${TARGET} PUBLIC .) +target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) + +target_link_libraries(${TARGET} PRIVATE common mtmd) + +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +if (BUILD_SHARED_LIBS) + set_target_properties (${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(${TARGET} PRIVATE LLAMA_BUILD) + target_compile_definitions(${TARGET} PUBLIC LLAMA_SHARED) +endif() + +set(serve_PUBLIC_HEADERS + ${CMAKE_CURRENT_SOURCE_DIR}/server-core.h + ${CMAKE_CURRENT_SOURCE_DIR}/server-core.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/server-utils.hpp + ) + +set_target_properties(${TARGET} + PROPERTIES + PUBLIC_HEADER "${serve_PUBLIC_HEADERS}") + +install(TARGETS ${TARGET} LIBRARY PUBLIC_HEADER) + +if (TARGET BUILD_INFO) + add_dependencies(${TARGET} BUILD_INFO) +endif() diff --git a/tools/server-core/README.md b/tools/server-core/README.md new file mode 100644 index 0000000000000..a569ca3a050f5 --- /dev/null +++ b/tools/server-core/README.md @@ -0,0 +1,3 @@ +# LLaMA.cpp Server core + +Core of abstract server API for **llama.cpp**. diff --git a/tools/server-core/server-core.cpp b/tools/server-core/server-core.cpp new file mode 100644 index 0000000000000..8bff2bd99fe40 --- /dev/null +++ b/tools/server-core/server-core.cpp @@ -0,0 +1,3 @@ +#include "server-core.h" + +// placeholder, some implementation can be moved here diff --git a/tools/server-core/server-core.h b/tools/server-core/server-core.h new file mode 100644 index 0000000000000..c816b63c81ef2 --- /dev/null +++ b/tools/server-core/server-core.h @@ -0,0 +1,3 @@ +#pragma once + +// placeholder, extern API to appear here diff --git a/tools/server-core/server-core.hpp b/tools/server-core/server-core.hpp new file mode 100644 index 0000000000000..334e9b88ecf07 --- /dev/null +++ b/tools/server-core/server-core.hpp @@ -0,0 +1,3627 @@ +#include "chat.h" +#include "server-utils.hpp" + +#include "arg.h" +#include "common.h" +#include "json-schema-to-grammar.h" +#include "llama.h" +#include "log.h" +#include "sampling.h" +#include "speculative.h" +#include "mtmd.h" +#include "mtmd-helper.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +constexpr int HTTP_POLLING_SECONDS = 1; + +enum stop_type { + STOP_TYPE_NONE, + STOP_TYPE_EOS, + STOP_TYPE_WORD, + STOP_TYPE_LIMIT, +}; + +// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 +enum slot_state { + SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future + SLOT_STATE_PROCESSING_PROMPT, + SLOT_STATE_DONE_PROMPT, + SLOT_STATE_GENERATING, +}; + +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded +}; + +enum server_task_type { + SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_EMBEDDING, + SERVER_TASK_TYPE_RERANK, + SERVER_TASK_TYPE_INFILL, + SERVER_TASK_TYPE_CANCEL, + SERVER_TASK_TYPE_NEXT_RESPONSE, + SERVER_TASK_TYPE_METRICS, + SERVER_TASK_TYPE_SLOT_SAVE, + SERVER_TASK_TYPE_SLOT_RESTORE, + SERVER_TASK_TYPE_SLOT_ERASE, + SERVER_TASK_TYPE_SET_LORA, +}; + +enum oaicompat_type { + OAICOMPAT_TYPE_NONE, + OAICOMPAT_TYPE_CHAT, + OAICOMPAT_TYPE_COMPLETION, + OAICOMPAT_TYPE_EMBEDDING, +}; + +// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 +enum error_type { + ERROR_TYPE_INVALID_REQUEST, + ERROR_TYPE_AUTHENTICATION, + ERROR_TYPE_SERVER, + ERROR_TYPE_NOT_FOUND, + ERROR_TYPE_PERMISSION, + ERROR_TYPE_UNAVAILABLE, // custom error + ERROR_TYPE_NOT_SUPPORTED, // custom error +}; + +static bool server_task_type_need_embd(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + return true; + default: + return false; + } +} + +static bool server_task_type_need_logits(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + return true; + default: + return false; + } +} + +struct slot_params { + bool stream = true; + bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt + bool return_tokens = false; + + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half + int32_t n_predict = -1; // new tokens to predict + int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters + + int64_t t_max_prompt_ms = -1; // TODO: implement + int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit + + std::vector lora; + + std::vector antiprompt; + std::vector response_fields; + bool timings_per_token = false; + bool post_sampling_probs = false; + bool ignore_eos = false; + + struct common_params_sampling sampling; + struct common_params_speculative speculative; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_syntax oaicompat_chat_syntax; + + json to_json() const { + std::vector samplers; + samplers.reserve(sampling.samplers.size()); + for (const auto & sampler : sampling.samplers) { + samplers.emplace_back(common_sampler_type_to_str(sampler)); + } + + json lora = json::array(); + for (size_t i = 0; i < this->lora.size(); ++i) { + lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + } + + auto grammar_triggers = json::array(); + for (const auto & trigger : sampling.grammar_triggers) { + server_grammar_trigger ct(std::move(trigger)); + grammar_triggers.push_back(ct.to_json()); + } + + return json { + {"n_predict", n_predict}, // Server configured n_predict + {"seed", sampling.seed}, + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"top_n_sigma", sampling.top_n_sigma}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"dry_sequence_breakers", sampling.dry_sequence_breakers}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"stop", antiprompt}, + {"max_tokens", n_predict}, // User configured n_predict + {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + {"logit_bias", format_logit_bias(sampling.logit_bias)}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"grammar", sampling.grammar}, + {"grammar_lazy", sampling.grammar_lazy}, + {"grammar_triggers", grammar_triggers}, + {"preserved_tokens", sampling.preserved_tokens}, + {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, + {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, + {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, + {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, + {"samplers", samplers}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, + }; + } +}; + +struct server_task { + int id = -1; // to be filled by server_queue + int index = -1; // used when there are multiple prompts (batch request) + + server_task_type type; + + // used by SERVER_TASK_TYPE_CANCEL + int id_target = -1; + + // used by SERVER_TASK_TYPE_INFERENCE + slot_params params; + server_tokens prompt_tokens; + int id_selected_slot = -1; + + // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE + struct slot_action { + int slot_id; + std::string filename; + std::string filepath; + }; + slot_action slot_action; + + // used by SERVER_TASK_TYPE_METRICS + bool metrics_reset_bucket = false; + + // used by SERVER_TASK_TYPE_SET_LORA + std::vector set_lora; + + server_task(server_task_type type) : type(type) {} + + static slot_params params_from_json_cmpl( + const llama_context * ctx, + const common_params & params_base, + const json & data) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + slot_params params; + + // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) + slot_params defaults; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + defaults.n_keep = params_base.n_keep; + + // enabling this will output extra debug information in the HTTP responses from the server + params.verbose = params_base.verbosity > 9; + params.timings_per_token = json_value(data, "timings_per_token", false); + + params.stream = json_value(data, "stream", false); + params.cache_prompt = json_value(data, "cache_prompt", true); + params.return_tokens = json_value(data, "return_tokens", false); + params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); + params.n_indent = json_value(data, "n_indent", defaults.n_indent); + params.n_keep = json_value(data, "n_keep", defaults.n_keep); + params.n_discard = json_value(data, "n_discard", defaults.n_discard); + //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement + params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); + params.response_fields = json_value(data, "response_fields", std::vector()); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); + params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); + params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); + + params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); + params.speculative.n_min = std::max(params.speculative.n_min, 0); + params.speculative.n_max = std::max(params.speculative.n_max, 0); + + // Use OpenAI API logprobs only if n_probs wasn't provided + if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ + params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); + } + + if (data.contains("lora")) { + if (data.at("lora").is_array()) { + params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); + } else { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + } else { + params.lora = params_base.lora_adapters; + } + + // TODO: add more sanity checks for the input parameters + + if (params.sampling.penalty_last_n < -1) { + throw std::runtime_error("Error: repeat_last_n must be >= -1"); + } + + if (params.sampling.dry_penalty_last_n < -1) { + throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); + } + + if (params.sampling.penalty_last_n == -1) { + // note: should be the slot's context and not the full context, but it's ok + params.sampling.penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_base < 1.0f) { + params.sampling.dry_base = defaults.sampling.dry_base; + } + + // sequence breakers for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (params.sampling.dry_sequence_breakers.empty()) { + throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); + } + } + } + + // process "json_schema" and "grammar" + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); + params.sampling.grammar = json_schema_to_grammar(schema); + SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); + } catch (const std::exception & e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } + } else { + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); + params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); + SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); + } + + { + auto it = data.find("chat_format"); + if (it != data.end()) { + params.oaicompat_chat_syntax.format = static_cast(it->get()); + SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format)); + } else { + params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; + } + params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format; + params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (params_base.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); + params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); + params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); + } + + { + const auto preserved_tokens = data.find("preserved_tokens"); + if (preserved_tokens != data.end()) { + for (const auto & t : *preserved_tokens) { + auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + SRV_DBG("Preserved token: %d\n", ids[0]); + params.sampling.preserved_tokens.insert(ids[0]); + } else { + // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. + SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); + } + } + } + const auto grammar_triggers = data.find("grammar_triggers"); + if (grammar_triggers != data.end()) { + for (const auto & t : *grammar_triggers) { + server_grammar_trigger ct(t); + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto & word = ct.value.value; + auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + auto token = ids[0]; + if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) { + throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); + } + SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); + common_grammar_trigger trigger; + trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; + trigger.value = word; + trigger.token = token; + params.sampling.grammar_triggers.push_back(std::move(trigger)); + } else { + SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); + params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); + } + } else { + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { + SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); + } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { + SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); + } else { + throw std::runtime_error("Unknown grammar trigger type"); + } + params.sampling.grammar_triggers.emplace_back(std::move(ct.value)); + } + } + } + if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { + throw std::runtime_error("Error: no triggers set for lazy grammar!"); + } + } + + { + params.sampling.logit_bias.clear(); + params.ignore_eos = json_value(data, "ignore_eos", false); + + const auto & logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto & el : *logit_bias) { + // TODO: we may want to throw errors here, in case "el" is incorrect + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else if (el[0].is_string()) { + auto toks = common_tokenize(vocab, el[0].get(), false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + } + } + + { + params.antiprompt.clear(); + + const auto & stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto & word : *stop) { + if (!word.empty()) { + params.antiprompt.push_back(word); + } + } + } + } + + { + const auto samplers = data.find("samplers"); + if (samplers != data.end()) { + if (samplers->is_array()) { + params.sampling.samplers = common_sampler_types_from_names(*samplers, false); + } else if (samplers->is_string()){ + params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); + } + } else { + params.sampling.samplers = defaults.sampling.samplers; + } + } + + std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; + params.oaicompat_model = json_value(data, "model", model_name); + + return params; + } + + // utility function + static std::unordered_set get_list_id(const std::vector & tasks) { + std::unordered_set ids(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + ids.insert(tasks[i].id); + } + return ids; + } +}; + +struct result_timings { + int32_t prompt_n = -1; + double prompt_ms; + double prompt_per_token_ms; + double prompt_per_second; + + int32_t predicted_n = -1; + double predicted_ms; + double predicted_per_token_ms; + double predicted_per_second; + + // Optional speculative metrics - only included when > 0 + int32_t draft_n = 0; + int32_t draft_n_accepted = 0; + + json to_json() const { + json base = { + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, + {"prompt_per_token_ms", prompt_per_token_ms}, + {"prompt_per_second", prompt_per_second}, + + {"predicted_n", predicted_n}, + {"predicted_ms", predicted_ms}, + {"predicted_per_token_ms", predicted_per_token_ms}, + {"predicted_per_second", predicted_per_second}, + }; + + if (draft_n > 0) { + base["draft_n"] = draft_n; + base["draft_n_accepted"] = draft_n_accepted; + } + + return base; + } +}; + +struct server_task_result { + int id = -1; + int id_slot = -1; + virtual bool is_error() { + // only used by server_task_result_error + return false; + } + virtual bool is_stop() { + // only used by server_task_result_cmpl_* + return false; + } + virtual int get_index() { + return -1; + } + virtual json to_json() = 0; + virtual ~server_task_result() = default; +}; + +// using shared_ptr for polymorphism of server_task_result +using server_task_result_ptr = std::unique_ptr; + +inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: return "eos"; + case STOP_TYPE_WORD: return "word"; + case STOP_TYPE_LIMIT: return "limit"; + default: return "none"; + } +} + +struct completion_token_output { + llama_token tok; + float prob; + std::string text_to_send; + struct prob_info { + llama_token tok; + std::string txt; + float prob; + }; + std::vector probs; + + json to_json(bool post_sampling_probs) const { + json probs_for_token = json::array(); + for (const auto & p : probs) { + std::string txt(p.txt); + txt.resize(validate_utf8(txt)); + probs_for_token.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.txt)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + }); + } + return probs_for_token; + } + + static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { + json out = json::array(); + for (const auto & p : probs) { + std::string txt(p.text_to_send); + txt.resize(validate_utf8(txt)); + out.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.text_to_send)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + { + post_sampling_probs ? "top_probs" : "top_logprobs", + p.to_json(post_sampling_probs) + }, + }); + } + return out; + } + + static float logarithm(float x) { + // nlohmann::json converts -inf to null, so we need to prevent that + return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); + } + + static std::vector str_to_bytes(const std::string & str) { + std::vector bytes; + for (unsigned char c : str) { + bytes.push_back(c); + } + return bytes; + } +}; + +struct server_task_result_cmpl_final : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + bool stream; + result_timings timings; + std::string prompt; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + int32_t n_tokens_cached; + bool has_new_line; + std::string stopping_word; + stop_type stop = STOP_TYPE_NONE; + + bool post_sampling_probs; + std::vector probs_output; + std::vector response_fields; + + slot_params generation_params; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_msg oaicompat_msg; + std::vector oaicompat_msg_diffs; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return true; // in stream mode, final responses are considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + json res = json { + {"index", index}, + {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"tokens", stream ? llama_tokens {} : tokens}, + {"id_slot", id_slot}, + {"stop", true}, + {"model", oaicompat_model}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + {"generation_settings", generation_params.to_json()}, + {"prompt", prompt}, + {"has_new_line", has_new_line}, + {"truncated", truncated}, + {"stop_type", stop_type_to_str(stop)}, + {"stopping_word", stopping_word}, + {"tokens_cached", n_tokens_cached}, + {"timings", timings.to_json()}, + }; + if (!stream && !probs_output.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); + } + return response_fields.empty() ? res : json_get_nested_values(response_fields, res); + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (!stream && probs_output.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + json finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + json res = json { + {"choices", json::array({ + json{ + {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", finish_reason}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + std::string finish_reason = "length"; + common_chat_msg msg; + if (!oaicompat_msg.empty()) { + msg = oaicompat_msg; + } else { + msg.role = "assistant"; + msg.content = content; + } + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; + } + + json choice { + {"finish_reason", finish_reason}, + {"index", 0}, + {"message", msg.to_json_oaicompat()}, + }; + + if (!stream && probs_output.size() > 0) { + choice["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + + std::time_t t = std::time(0); + + json res = json { + {"choices", json::array({choice})}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat_stream() { + std::time_t t = std::time(0); + std::string finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls"; + } + + json deltas = json::array(); + for (const auto & diff : oaicompat_msg_diffs) { + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + } + + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }); + + if (timings.prompt_n >= 0) { + deltas.back().push_back({"timings", timings.to_json()}); + } + + // extra fields for debugging purposes + if (verbose && !deltas.empty()) { + deltas.front()["__verbose"] = to_json_non_oaicompat(); + } + + return deltas; + } +}; + +struct server_task_result_cmpl_partial : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + int32_t n_decoded; + int32_t n_prompt_tokens; + + bool post_sampling_probs; + completion_token_output prob_output; + result_timings timings; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + std::vector oaicompat_msg_diffs; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return false; // in stream mode, partial responses are not considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + // non-OAI-compat JSON + json res = json { + {"index", index}, + {"content", content}, + {"tokens", tokens}, + {"stop", false}, + {"id_slot", id_slot}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + }; + // populate the timings object when needed (usually for the last response or with timings_per_token enabled) + if (timings.prompt_n > 0) { + res.push_back({"timings", timings.to_json()}); + } + if (!prob_output.probs.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); + } + return res; + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (prob_output.probs.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + json res = json { + {"choices", json::array({ + json{ + {"text", content}, + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", nullptr}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + bool first = n_decoded == 1; + std::time_t t = std::time(0); + json choices; + + std::vector deltas; + auto add_delta = [&](const json & delta) { + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", delta}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + }; + // We have to send an initial update to conform to openai behavior + if (first) { + add_delta({ + {"role", "assistant"}, + {"content", nullptr}, + }); + } + + for (const auto & diff : oaicompat_msg_diffs) { + add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); + } + + if (!deltas.empty()) { + GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1); + + if (prob_output.probs.size() > 0) { + deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json { + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + + if (timings.prompt_n >= 0) { + deltas[deltas.size() - 1].push_back({"timings", timings.to_json()}); + } + } + + return deltas; + } +}; + +struct server_task_result_embd : server_task_result { + int index = 0; + std::vector> embedding; + + int32_t n_tokens; + + // OAI-compat fields + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? to_json_oaicompat() + : to_json_non_oaicompat(); + } + + json to_json_non_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding}, + }; + } + + json to_json_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding[0]}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +struct server_task_result_rerank : server_task_result { + int index = 0; + float score = -1e6; + + int32_t n_tokens; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return json { + {"index", index}, + {"score", score}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +// this function maybe used outside of server_task_result_error +static json format_error_response(const std::string & message, const enum error_type type) { + std::string type_str; + int code = 500; + switch (type) { + case ERROR_TYPE_INVALID_REQUEST: + type_str = "invalid_request_error"; + code = 400; + break; + case ERROR_TYPE_AUTHENTICATION: + type_str = "authentication_error"; + code = 401; + break; + case ERROR_TYPE_NOT_FOUND: + type_str = "not_found_error"; + code = 404; + break; + case ERROR_TYPE_SERVER: + type_str = "server_error"; + code = 500; + break; + case ERROR_TYPE_PERMISSION: + type_str = "permission_error"; + code = 403; + break; + case ERROR_TYPE_NOT_SUPPORTED: + type_str = "not_supported_error"; + code = 501; + break; + case ERROR_TYPE_UNAVAILABLE: + type_str = "unavailable_error"; + code = 503; + break; + } + return json { + {"code", code}, + {"message", message}, + {"type", type_str}, + }; +} + +struct server_task_result_error : server_task_result { + int index = 0; + error_type err_type = ERROR_TYPE_SERVER; + std::string err_msg; + + virtual bool is_error() override { + return true; + } + + virtual json to_json() override { + return format_error_response(err_msg, err_type); + } +}; + +struct server_task_result_metrics : server_task_result { + int n_idle_slots; + int n_processing_slots; + int n_tasks_deferred; + int64_t t_start; + + // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + // while we can also use std::vector this requires copying the slot object which can be quite messy + // therefore, we use json to temporarily store the slot.to_json() result + json slots_data = json::array(); + + virtual json to_json() override { + return json { + { "idle", n_idle_slots }, + { "processing", n_processing_slots }, + { "deferred", n_tasks_deferred }, + { "t_start", t_start }, + + { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, + { "t_tokens_generation_total", t_tokens_generation_total }, + { "n_tokens_predicted_total", n_tokens_predicted_total }, + { "t_prompt_processing_total", t_prompt_processing_total }, + + { "n_prompt_tokens_processed", n_prompt_tokens_processed }, + { "t_prompt_processing", t_prompt_processing }, + { "n_tokens_predicted", n_tokens_predicted }, + { "t_tokens_generation", t_tokens_generation }, + + { "n_decode_total", n_decode_total }, + { "n_busy_slots_total", n_busy_slots_total }, + + { "slots", slots_data }, + }; + } +}; + +struct server_task_result_slot_save_load : server_task_result { + std::string filename; + bool is_save; // true = save, false = load + + size_t n_tokens; + size_t n_bytes; + double t_ms; + + virtual json to_json() override { + if (is_save) { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", n_tokens }, + { "n_written", n_bytes }, + { "timings", { + { "save_ms", t_ms } + }}, + }; + } else { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; + } + } +}; + +struct server_task_result_slot_erase : server_task_result { + size_t n_erased; + + virtual json to_json() override { + return json { + { "id_slot", id_slot }, + { "n_erased", n_erased }, + }; + } +}; + +struct server_task_result_apply_lora : server_task_result { + virtual json to_json() override { + return json {{ "success", true }}; + } +}; + +struct server_slot { + int id; + int id_task = -1; + + // only used for completion/embedding/infill/rerank + server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; + + llama_batch batch_spec = {}; + + llama_context * ctx = nullptr; + llama_context * ctx_dft = nullptr; + + // multimodal + mtmd_context * mctx = nullptr; + + common_speculative * spec = nullptr; + + std::vector lora; + + // the index relative to completion multi-task request + size_t index = 0; + + struct slot_params params; + + slot_state state = SLOT_STATE_IDLE; + + // used to determine the slot that has been used the longest + int64_t t_last_used = -1; + + // generation props + int32_t n_ctx = 0; // context size per slot + int32_t n_past = 0; + int32_t n_decoded = 0; + int32_t n_remaining = -1; + int32_t i_batch = -1; + int32_t n_predict = -1; // TODO: disambiguate from params.n_predict + + // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated + int32_t n_prompt_tokens = 0; + int32_t n_prompt_tokens_processed = 0; + + // input prompt tokens + server_tokens prompt_tokens; + + size_t last_nl_pos = 0; + + std::string generated_text; + llama_tokens generated_tokens; + common_chat_msg chat_msg; + + server_tokens cache_tokens; + + std::vector generated_token_probs; + + bool has_next_token = true; + bool has_new_line = false; + bool truncated = false; + stop_type stop; + + std::string stopping_word; + + // sampling + json json_schema; + + struct common_sampler * smpl = nullptr; + + llama_token sampled; + + common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + std::vector generated_tool_call_ids; + + // stats + size_t n_sent_text = 0; // number of sent text character + + int64_t t_start_process_prompt; + int64_t t_start_generation; + + double t_prompt_processing; // ms + double t_token_generation; // ms + + std::function callback_on_release; + + // Speculative decoding stats + int32_t n_draft_total = 0; // Total draft tokens generated + int32_t n_draft_accepted = 0; // Draft tokens actually accepted + + void reset() { + SLT_DBG(*this, "%s", "\n"); + + n_prompt_tokens = 0; + last_nl_pos = 0; + generated_text = ""; + has_new_line = false; + truncated = false; + stop = STOP_TYPE_NONE; + stopping_word = ""; + n_past = 0; + n_sent_text = 0; + task_type = SERVER_TASK_TYPE_COMPLETION; + chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + generated_tokens.clear(); + generated_token_probs.clear(); + chat_msg = {}; + json_schema = json(); + generated_tool_call_ids.clear(); + + // clear speculative decoding stats + n_draft_total = 0; + n_draft_accepted = 0; + } + + bool need_embd() const { + return server_task_type_need_embd(task_type); + } + + bool need_logits() const { + return server_task_type_need_logits(task_type); + } + + // if the context does not have a memory module then all embeddings have to be computed within a single ubatch + // also we cannot split if the pooling would require any past tokens + bool can_split() const { + return + !need_embd() || + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + } + + bool can_batch_with(server_slot & other_slot) const { + return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora); + } + + bool has_budget(const common_params & global_params) { + if (params.n_predict == -1 && global_params.n_predict == -1) { + return true; // limitless + } + + n_remaining = -1; + + if (params.n_predict != -1) { + n_remaining = params.n_predict - n_decoded; + } else if (global_params.n_predict != -1) { + n_remaining = global_params.n_predict - n_decoded; + } + + return n_remaining > 0; // no budget + } + + bool is_processing() const { + return state != SLOT_STATE_IDLE; + } + + bool can_speculate() const { + return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; + } + + void add_token(const completion_token_output & token) { + if (!is_processing()) { + SLT_WRN(*this, "%s", "slot is not processing\n"); + return; + } + generated_token_probs.push_back(token); + } + + void release() { + if (is_processing()) { + SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); + + t_last_used = ggml_time_us(); + t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; + state = SLOT_STATE_IDLE; + callback_on_release(id); + } + } + + result_timings get_timings() const { + result_timings timings; + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; + timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; + timings.predicted_per_token_ms = t_token_generation / n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + + // Add speculative metrics + if (n_draft_total > 0) { + timings.draft_n = n_draft_total; + timings.draft_n_accepted = n_draft_accepted; + } + + return timings; + } + + const common_chat_msg & update_chat_msg(std::vector & diffs) { + auto previous_msg = chat_msg; + SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); + auto new_msg = common_chat_parse( + generated_text, + /* is_partial= */ stop != STOP_TYPE_EOS, + params.oaicompat_chat_syntax); + if (!new_msg.empty()) { + new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id); + chat_msg = new_msg; + diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg); + } + return chat_msg; + } + + size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { + size_t stop_pos = std::string::npos; + + for (const std::string & word : params.antiprompt) { + size_t pos; + + if (is_full_stop) { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; + + pos = text.find(word, from_pos); + } else { + // otherwise, partial stop + pos = string_find_partial_stop(text, word); + } + + if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { + if (is_full_stop) { + stop = STOP_TYPE_WORD; + stopping_word = word; + has_next_token = false; + } + stop_pos = pos; + } + } + + return stop_pos; + } + + void print_timings() const { + const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; + const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + const double t_gen = t_token_generation / n_decoded; + const double n_gen_second = 1e3 / t_token_generation * n_decoded; + + SLT_INF(*this, + "\n" + "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " total time = %10.2f ms / %5d tokens\n", + t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, + t_token_generation, n_decoded, t_gen, n_gen_second, + t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + + if (n_draft_total > 0) { + const float draft_ratio = (float) n_draft_accepted / n_draft_total; + SLT_INF(*this, + "\n" + "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", + draft_ratio, n_draft_accepted, n_draft_total + ); + } + } + + json to_json() const { + return json { + {"id", id}, + {"id_task", id_task}, + {"n_ctx", n_ctx}, + {"speculative", can_speculate()}, + {"is_processing", is_processing()}, + {"params", params.to_json()}, + {"prompt", prompt_tokens.detokenize(ctx, true)}, + {"next_token", + { + {"has_next_token", has_next_token}, + {"has_new_line", has_new_line}, + {"n_remain", n_remaining}, + {"n_decoded", n_decoded}, + {"stopping_word", stopping_word}, + } + }, + }; + } +}; + +struct server_metrics { + int64_t t_start = 0; + + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + void init() { + t_start = ggml_time_us(); + } + + void on_prompt_eval(const server_slot & slot) { + n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; + n_prompt_tokens_processed += slot.n_prompt_tokens_processed; + t_prompt_processing += slot.t_prompt_processing; + t_prompt_processing_total += slot.t_prompt_processing; + } + + void on_prediction(const server_slot & slot) { + n_tokens_predicted_total += slot.n_decoded; + n_tokens_predicted += slot.n_decoded; + t_tokens_generation += slot.t_token_generation; + t_tokens_generation_total += slot.t_token_generation; + } + + void on_decoded(const std::vector & slots) { + n_decode_total++; + for (const auto & slot : slots) { + if (slot.is_processing()) { + n_busy_slots_total++; + } + } + } + + void reset_bucket() { + n_prompt_tokens_processed = 0; + t_prompt_processing = 0; + n_tokens_predicted = 0; + t_tokens_generation = 0; + } +}; + +struct server_queue { + int id = 0; + bool running; + + // queues + std::deque queue_tasks; + std::deque queue_tasks_deferred; + + std::mutex mutex_tasks; + std::condition_variable condition_tasks; + + // callback functions + std::function callback_new_task; + std::function callback_update_slots; + + // Add a new task to the end of the queue + int post(server_task && task, bool front = false) { + std::unique_lock lock(mutex_tasks); + GGML_ASSERT(task.id != -1); + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + const int task_id = task.id; + QUE_DBG("new task, id = %d, front = %d\n", task_id, front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + condition_tasks.notify_one(); + return task_id; + } + + // multi-task version of post() + int post(std::vector && tasks, bool front = false) { + std::unique_lock lock(mutex_tasks); + for (auto & task : tasks) { + if (task.id == -1) { + task.id = id++; + } + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + } + condition_tasks.notify_one(); + return 0; + } + + // Add a new task, but defer until one slot is available + void defer(server_task && task) { + std::unique_lock lock(mutex_tasks); + QUE_DBG("defer task, id = %d\n", task.id); + queue_tasks_deferred.push_back(std::move(task)); + condition_tasks.notify_one(); + } + + // Get the next id for creating a new task + int get_new_id() { + std::unique_lock lock(mutex_tasks); + int new_id = id++; + return new_id; + } + + // Register function to process a new task + void on_new_task(std::function callback) { + callback_new_task = std::move(callback); + } + + // Register the function to be called when all slots data is ready to be processed + void on_update_slots(std::function callback) { + callback_update_slots = std::move(callback); + } + + // Call when the state of one slot is changed, it will move one task from deferred to main queue + void pop_deferred_task() { + std::unique_lock lock(mutex_tasks); + if (!queue_tasks_deferred.empty()) { + queue_tasks.emplace_back(std::move(queue_tasks_deferred.front())); + queue_tasks_deferred.pop_front(); + } + condition_tasks.notify_one(); + } + + // end the start_loop routine + void terminate() { + std::unique_lock lock(mutex_tasks); + running = false; + condition_tasks.notify_all(); + } + + /** + * Main loop consists of these steps: + * - Wait until a new task arrives + * - Process the task (i.e. maybe copy data into slot) + * - Check if multitask is finished + * - Update all slots + */ + void start_loop() { + running = true; + + while (true) { + QUE_DBG("%s", "processing new tasks\n"); + + while (true) { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + lock.unlock(); + break; + } + server_task task = std::move(queue_tasks.front()); + queue_tasks.pop_front(); + lock.unlock(); + + QUE_DBG("processing task, id = %d\n", task.id); + callback_new_task(std::move(task)); + } + + // all tasks in the current loop is processed, slots data is now ready + QUE_DBG("%s", "update slots\n"); + + callback_update_slots(); + + QUE_DBG("%s", "waiting for new tasks\n"); + { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + condition_tasks.wait(lock, [&]{ + return (!queue_tasks.empty() || !running); + }); + } + } + } + } + +private: + void cleanup_pending_task(int id_target) { + // no need lock because this is called exclusively by post() + auto rm_func = [id_target](const server_task & task) { + return task.id_target == id_target; + }; + queue_tasks.erase( + std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), + queue_tasks.end()); + queue_tasks_deferred.erase( + std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), + queue_tasks_deferred.end()); + } +}; + +struct server_response { + bool running = true; + + // for keeping track of all tasks waiting for the result + std::unordered_set waiting_task_ids; + + // the main result queue (using ptr for polymorphism) + std::vector queue_results; + + std::mutex mutex_results; + std::condition_variable condition_results; + + // add the id_task to the list of tasks waiting for response + void add_waiting_task_id(int id_task) { + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.insert(id_task); + } + + void add_waiting_tasks(const std::vector & tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & task : tasks) { + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); + waiting_task_ids.insert(task.id); + } + } + + // when the request is finished, we can remove task associated with it + void remove_waiting_task_id(int id_task) { + SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.erase(id_task); + // make sure to clean up all pending results + queue_results.erase( + std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { + return res->id == id_task; + }), + queue_results.end()); + } + + void remove_waiting_task_ids(const std::unordered_set & id_tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & id_task : id_tasks) { + SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + waiting_task_ids.erase(id_task); + } + } + + // This function blocks the thread until there is a response for one of the id_tasks + server_task_result_ptr recv(const std::unordered_set & id_tasks) { + while (true) { + std::unique_lock lock(mutex_results); + condition_results.wait(lock, [&]{ + if (!running) { + SRV_DBG("%s : queue result stop\n", __func__); + std::terminate(); // we cannot return here since the caller is HTTP code + } + return !queue_results.empty(); + }); + + for (size_t i = 0; i < queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + + // should never reach here + } + + // same as recv(), but have timeout in seconds + // if timeout is reached, nullptr is returned + server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { + while (true) { + std::unique_lock lock(mutex_results); + + for (int i = 0; i < (int) queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + + std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); + if (!running) { + SRV_DBG("%s : queue result stop\n", __func__); + std::terminate(); // we cannot return here since the caller is HTTP code + } + if (cr_res == std::cv_status::timeout) { + return nullptr; + } + } + + // should never reach here + } + + // single-task version of recv() + server_task_result_ptr recv(int id_task) { + std::unordered_set id_tasks = {id_task}; + return recv(id_tasks); + } + + // Send a new result to a waiting id_task + void send(server_task_result_ptr && result) { + SRV_DBG("sending result for task id = %d\n", result->id); + + std::unique_lock lock(mutex_results); + for (const auto & id_task : waiting_task_ids) { + if (result->id == id_task) { + SRV_DBG("task id = %d pushed to result queue\n", result->id); + + queue_results.emplace_back(std::move(result)); + condition_results.notify_all(); + return; + } + } + } + + // terminate the waiting loop + void terminate() { + running = false; + condition_results.notify_all(); + } +}; + +struct server_context { + common_params params_base; + + // note: keep these alive - they determine the lifetime of the model, context, etc. + common_init_result llama_init; + common_init_result llama_init_dft; + + llama_model * model = nullptr; + llama_context * ctx = nullptr; + + // multimodal + mtmd_context * mctx = nullptr; + + const llama_vocab * vocab = nullptr; + + llama_model * model_dft = nullptr; + + llama_context_params cparams_dft; + + llama_batch batch {}; + + bool clean_kv_cache = true; + bool add_bos_token = true; + bool has_eos_token = false; + + int32_t n_ctx; // total context for all clients / slots + + // slots / clients + std::vector slots; + json default_generation_settings_for_props; + + server_queue queue_tasks; + server_response queue_results; + + server_metrics metrics; + + // Necessary similarity of prompt for slot selection + float slot_prompt_similarity = 0.0f; + + common_chat_templates_ptr chat_templates; + oaicompat_parser_options oai_parser_opt; + + ~server_context() { + mtmd_free(mctx); + + // Clear any sampling context + for (server_slot & slot : slots) { + common_sampler_free(slot.smpl); + slot.smpl = nullptr; + + llama_free(slot.ctx_dft); + slot.ctx_dft = nullptr; + + common_speculative_free(slot.spec); + slot.spec = nullptr; + + llama_batch_free(slot.batch_spec); + } + + llama_batch_free(batch); + } + + bool load_model(const common_params & params) { + SRV_INF("loading model '%s'\n", params.model.path.c_str()); + + params_base = params; + + llama_init = common_init_from_params(params_base); + + model = llama_init.model.get(); + ctx = llama_init.context.get(); + + if (model == nullptr) { + SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); + return false; + } + + vocab = llama_model_get_vocab(model); + + n_ctx = llama_n_ctx(ctx); + + add_bos_token = llama_vocab_get_add_bos(vocab); + has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; + + if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) { + SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); + + auto params_dft = params_base; + + params_dft.devices = params_base.speculative.devices; + params_dft.model = params_base.speculative.model; + params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; + params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; + params_dft.n_parallel = 1; + params_dft.cache_type_k = params_base.speculative.cache_type_k; + params_dft.cache_type_v = params_base.speculative.cache_type_v; + + llama_init_dft = common_init_from_params(params_dft); + + model_dft = llama_init_dft.model.get(); + + if (model_dft == nullptr) { + SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); + return false; + } + + if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) { + SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); + + return false; + } + + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); + + cparams_dft = common_context_params_to_llama(params_dft); + cparams_dft.n_batch = n_ctx_dft; + + // the context is not needed - we will create one for each slot + llama_init_dft.context.reset(); + } + + chat_templates = common_chat_templates_init(model, params_base.chat_template); + try { + common_chat_format_example(chat_templates.get(), params.use_jinja); + } catch (const std::exception & e) { + SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); + SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); + chat_templates = common_chat_templates_init(model, "chatml"); + } + + std::string & mmproj_path = params_base.mmproj.path; + if (!mmproj_path.empty()) { + mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params_base.mmproj_use_gpu; + mparams.print_timings = false; + mparams.n_threads = params_base.cpuparams.n_threads; + mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; + mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + if (mctx == nullptr) { + SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); + return false; + } + SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); + } + + if (!params_base.speculative.model.path.empty()) { + SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); + return false; + } + } + + if (!llama_memory_can_shift(llama_get_memory(ctx))) { + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); + } + } + + return true; + } + + void init() { + const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; + + SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); + + for (int i = 0; i < params_base.n_parallel; i++) { + server_slot slot; + + slot.id = i; + slot.ctx = ctx; + slot.n_ctx = n_ctx_slot; + slot.n_predict = params_base.n_predict; + slot.mctx = mctx; + slot.cache_tokens.has_mtmd = mctx != nullptr; + + if (model_dft) { + slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); + + slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); + if (slot.ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create draft context\n"); + return; + } + + slot.spec = common_speculative_init(slot.ctx_dft); + if (slot.spec == nullptr) { + SRV_ERR("%s", "failed to create speculator\n"); + return; + } + } + + SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); + + slot.params.sampling = params_base.sampling; + slot.params.n_keep = params_base.n_keep; + + slot.callback_on_release = [this](int) { + queue_tasks.pop_deferred_task(); + }; + + slot.reset(); + + slots.push_back(std::move(slot)); + } + + default_generation_settings_for_props = slots[0].to_json(); + + // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens + // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) + { + const int32_t n_batch = llama_n_batch(ctx); + batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + } + + metrics.init(); + + oai_parser_opt = { + /* use_jinja */ params_base.use_jinja, + /* prefill_assistant */ params_base.prefill_assistant, + /* reasoning_format */ params_base.reasoning_format, + /* chat_template_kwargs */ params_base.default_template_kwargs, + /* common_chat_templates */ chat_templates.get(), + /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, + /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, + /* enable_thinking */ params_base.reasoning_budget != 0, + }; + } + + server_slot * get_slot_by_id(int id) { + for (server_slot & slot : slots) { + if (slot.id == id) { + return &slot; + } + } + + return nullptr; + } + + server_slot * get_available_slot(const server_task & task) { + server_slot * ret = nullptr; + + // find the slot that has at least n% prompt similarity + if (ret == nullptr && slot_prompt_similarity != 0.0f) { + int lcs_len = 0; + float similarity = 0; + + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + // skip the slot if it does not contains cached tokens + if (slot.cache_tokens.empty()) { + continue; + } + + // length of the Longest Common Subsequence between the current slot's prompt and the input prompt + int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); + + // fraction of the common subsequence length compared to the current slot's prompt length + float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); + + // select the current slot if the criteria match + if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { + lcs_len = cur_lcs_len; + similarity = cur_similarity; + ret = &slot; + } + } + + if (ret != nullptr) { + SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity); + } + } + + // find the slot that has been least recently used + if (ret == nullptr) { + int64_t t_last = -1; + + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + // select the current slot if the criteria match + if (!ret || slot.t_last_used <= t_last) { + t_last = slot.t_last_used; + ret = &slot; + } + } + + if (ret != nullptr) { + SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last); + } + } + + return ret; + } + + bool launch_slot_with_task(server_slot & slot, server_task && task) { + slot.reset(); + slot.id_task = task.id; + slot.index = task.index; + slot.task_type = task.type; + slot.params = std::move(task.params); + slot.prompt_tokens = std::move(task.prompt_tokens); + + if (!are_lora_equal(slot.params.lora, slot.lora)) { + // if lora is changed, we cannot reuse cached tokens + slot.cache_tokens.clear(); + slot.lora = slot.params.lora; + } + + if (!slot.prompt_tokens.validate(ctx)) { + send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); + return false; + } + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); + + if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { + // Might be better to reject the request with a 400 ? + SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict); + slot.params.n_predict = slot.n_predict; + } + + if (slot.params.ignore_eos && has_eos_token) { + slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY}); + } + + { + if (slot.smpl != nullptr) { + common_sampler_free(slot.smpl); + } + + slot.smpl = common_sampler_init(model, slot.params.sampling); + if (slot.smpl == nullptr) { + // for now, the only error that may happen here is invalid grammar + send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + + if (slot.ctx_dft) { + llama_batch_free(slot.batch_spec); + + slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + } + + slot.state = SLOT_STATE_STARTED; + + SLT_INF(slot, "%s", "processing task\n"); + + return true; + } + + void kv_cache_clear() { + SRV_DBG("%s", "clearing KV cache\n"); + + // clear the entire KV cache + llama_memory_clear(llama_get_memory(ctx), true); + clean_kv_cache = false; + } + + bool process_token(completion_token_output & result, server_slot & slot) { + // remember which tokens were sampled - used for repetition penalties during sampling + const std::string token_str = result.text_to_send; + slot.sampled = result.tok; + + slot.generated_text += token_str; + if (slot.params.return_tokens) { + slot.generated_tokens.push_back(result.tok); + } + slot.has_next_token = true; + + // check if there is incomplete UTF-8 character at the end + bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); + + // search stop word and delete it + if (!incomplete) { + size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); + + const std::string str_test = slot.generated_text.substr(pos); + bool send_text = true; + + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); + if (stop_pos != std::string::npos) { + slot.generated_text.erase( + slot.generated_text.begin() + pos + stop_pos, + slot.generated_text.end()); + pos = std::min(slot.n_sent_text, slot.generated_text.size()); + } else if (slot.has_next_token) { + stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); + send_text = stop_pos == std::string::npos; + } + + // check if there is any token to predict + if (send_text) { + // no send the stop word in the response + result.text_to_send = slot.generated_text.substr(pos, std::string::npos); + slot.n_sent_text += result.text_to_send.size(); + // add the token to slot queue and cache + } else { + result.text_to_send = ""; + } + + slot.add_token(result); + if (slot.params.stream) { + send_partial_response(slot, result); + } + } + + if (incomplete) { + slot.has_next_token = true; + } + + // if context shifting is disabled, make sure that we don't run out of context + if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx); + } + + // check the limits + if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); + } + + if (slot.has_new_line) { + // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent + if (slot.params.n_indent > 0) { + // check the current indentation + // TODO: improve by not doing it more than once for each new line + if (slot.last_nl_pos > 0) { + size_t pos = slot.last_nl_pos; + + int n_indent = 0; + while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { + n_indent++; + pos++; + } + + if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + // cut the last line + slot.generated_text.erase(pos, std::string::npos); + + SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); + } + } + + // find the next new line + { + const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); + + if (pos != std::string::npos) { + slot.last_nl_pos = pos + 1; + } + } + } + } + + // check if there is a new line in the generated text + if (result.text_to_send.find('\n') != std::string::npos) { + slot.has_new_line = true; + + // if we have seen a new line, we stop after a certain time limit, but only upon another new line + if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); + } + } + + // if context shift is disabled, we stop when it reaches the context limit + if (slot.n_past >= slot.n_ctx) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); + } + + if (llama_vocab_is_eog(vocab, result.tok)) { + slot.stop = STOP_TYPE_EOS; + slot.has_next_token = false; + + SLT_DBG(slot, "%s", "stopped by EOS\n"); + } + + const auto n_ctx_train = llama_model_n_ctx_train(model); + + if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; // stop prediction + + SLT_WRN(slot, + "n_predict (%d) is set for infinite generation. " + "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", + slot.params.n_predict, n_ctx_train); + } + + SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); + + return slot.has_next_token; // continue + } + + void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { + size_t n_probs = slot.params.sampling.n_probs; + size_t n_vocab = llama_vocab_n_tokens(vocab); + if (post_sampling) { + const auto * cur_p = common_sampler_get_candidates(slot.smpl); + const size_t max_probs = cur_p->size; + + // set probability for sampled token + for (size_t i = 0; i < max_probs; i++) { + if (cur_p->data[i].id == result.tok) { + result.prob = cur_p->data[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(max_probs); + for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { + result.probs.push_back({ + cur_p->data[i].id, + common_token_to_piece(ctx, cur_p->data[i].id, special), + cur_p->data[i].p + }); + } + } else { + // TODO: optimize this with min-p optimization + std::vector cur = get_token_probabilities(ctx, idx); + + // set probability for sampled token + for (size_t i = 0; i < n_vocab; i++) { + // set probability for sampled token + if (cur[i].id == result.tok) { + result.prob = cur[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(n_probs); + for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { + result.probs.push_back({ + cur[i].id, + common_token_to_piece(ctx, cur[i].id, special), + cur[i].p + }); + } + } + } + + void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(task.id, error, type); + } + + void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(slot.id_task, error, type); + } + + void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); + + auto res = std::make_unique(); + res->id = id_task; + res->err_type = type; + res->err_msg = error; + + queue_results.send(std::move(res)); + } + + // if multimodal is enabled, send an error and return false + bool ensure_no_mtmd(const int id_task) { + if (mctx) { + send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + return false; + } + return true; + } + + void send_partial_response(server_slot & slot, const completion_token_output & tkn) { + auto res = std::make_unique(); + + res->id = slot.id_task; + res->index = slot.index; + res->content = tkn.text_to_send; + res->tokens = { tkn.tok }; + + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->post_sampling_probs = slot.params.post_sampling_probs; + + res->verbose = slot.params.verbose; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + + slot.update_chat_msg(res->oaicompat_msg_diffs); + + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + res->prob_output = tkn; // copy the token probs + } + + // populate timings if this is final response or timings_per_token is enabled + if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) { + res->timings = slot.get_timings(); + } + + queue_results.send(std::move(res)); + } + + void send_final_response(server_slot & slot) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->id_slot = slot.id; + + res->index = slot.index; + res->content = slot.generated_text; + res->tokens = std::move(slot.generated_tokens); + res->timings = slot.get_timings(); + res->prompt = slot.prompt_tokens.detokenize(ctx, true); + res->response_fields = std::move(slot.params.response_fields); + + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_tokens_cached = slot.n_past; + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->post_sampling_probs = slot.params.post_sampling_probs; + + res->verbose = slot.params.verbose; + res->stream = slot.params.stream; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); + + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { + const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); + + size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end() - safe_offset); + } else { + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end()); + } + } + + res->generation_params = slot.params; // copy the parameters + + queue_results.send(std::move(res)); + } + + void send_embedding(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; + res->oaicompat = slot.params.oaicompat; + + const int n_embd = llama_model_n_embd(model); + + std::vector embd_res(n_embd, 0.0f); + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->embedding.push_back(std::vector(n_embd, 0.0f)); + continue; + } + + // normalize only when there is pooling + // TODO: configurable + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + } else { + res->embedding.push_back({ embd, embd + n_embd }); + } + } + + SLT_DBG(slot, "%s", "sending embeddings\n"); + + queue_results.send(std::move(res)); + } + + void send_rerank(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->score = -1e6; + continue; + } + + res->score = embd[0]; + } + + SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); + + queue_results.send(std::move(res)); + } + + // + // Functions to create new task(s) and receive result(s) + // + + void cancel_tasks(const std::unordered_set & id_tasks) { + std::vector cancel_tasks; + cancel_tasks.reserve(id_tasks.size()); + for (const auto & id_task : id_tasks) { + SRV_WRN("cancel task, id_task = %d\n", id_task); + + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_task; + queue_results.remove_waiting_task_id(id_task); + cancel_tasks.push_back(std::move(task)); + } + // push to beginning of the queue, so it has highest priority + queue_tasks.post(std::move(cancel_tasks), true); + } + + // receive the results from task(s) + void receive_multi_results( + const std::unordered_set & id_tasks, + const std::function&)> & result_handler, + const std::function & error_handler, + const std::function & is_connection_closed) { + std::vector results(id_tasks.size()); + for (int i = 0; i < (int)id_tasks.size(); i++) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; + } + + if (result == nullptr) { + i--; // retry + continue; + } + + if (result->is_error()) { + error_handler(result->to_json()); + cancel_tasks(id_tasks); + return; + } + + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + const size_t idx = result->get_index(); + GGML_ASSERT(idx < results.size() && "index out of range"); + results[idx] = std::move(result); + } + result_handler(results); + } + + // receive the results from task(s), in stream mode + void receive_cmpl_results_stream( + const std::unordered_set & id_tasks, + const std::function & result_handler, + const std::function & error_handler, + const std::function & is_connection_closed) { + size_t n_finished = 0; + while (true) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; + } + + if (result == nullptr) { + continue; // retry + } + + if (result->is_error()) { + error_handler(result->to_json()); + cancel_tasks(id_tasks); + return; + } + + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + if (!result_handler(result)) { + cancel_tasks(id_tasks); + break; + } + + if (result->is_stop()) { + if (++n_finished == id_tasks.size()) { + break; + } + } + } + } + + // + // Functions to process the task + // + + void process_single_task(server_task && task) { + switch (task.type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + { + const int id_slot = task.id_selected_slot; + + server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); + + if (slot == nullptr) { + // if no slot is available, we defer this task for processing later + SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + if (!launch_slot_with_task(*slot, std::move(task))) { + SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); + break; + } + } break; + case SERVER_TASK_TYPE_CANCEL: + { + // release slot linked with the task id + for (auto & slot : slots) { + if (slot.id_task == task.id_target) { + slot.release(); + break; + } + } + } break; + case SERVER_TASK_TYPE_NEXT_RESPONSE: + { + // do nothing + } break; + case SERVER_TASK_TYPE_METRICS: + { + json slots_data = json::array(); + + int n_idle_slots = 0; + int n_processing_slots = 0; + + for (server_slot & slot : slots) { + json slot_data = slot.to_json(); + + if (slot.is_processing()) { + n_processing_slots++; + } else { + n_idle_slots++; + } + + slots_data.push_back(slot_data); + } + SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); + + auto res = std::make_unique(); + res->id = task.id; + res->slots_data = std::move(slots_data); + res->n_idle_slots = n_idle_slots; + res->n_processing_slots = n_processing_slots; + res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); + res->t_start = metrics.t_start; + + res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; + res->t_prompt_processing_total = metrics.t_prompt_processing_total; + res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; + res->t_tokens_generation_total = metrics.t_tokens_generation_total; + + res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; + res->t_prompt_processing = metrics.t_prompt_processing; + res->n_tokens_predicted = metrics.n_tokens_predicted; + res->t_tokens_generation = metrics.t_tokens_generation; + + res->n_decode_total = metrics.n_decode_total; + res->n_busy_slots_total = metrics.n_busy_slots_total; + + if (task.metrics_reset_bucket) { + metrics.reset_bucket(); + } + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_SAVE: + { + if (!ensure_no_mtmd(task.id)) { + break; + } + + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + const size_t token_count = slot->cache_tokens.size(); + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + const llama_tokens & tokens = slot->cache_tokens.get_text_tokens(); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); + + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = true; + res->n_tokens = token_count; + res->n_bytes = nwrite; + res->t_ms = t_save_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: + { + if (!ensure_no_mtmd(task.id)) break; + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + llama_tokens tokens; + tokens.resize(slot->n_ctx); + size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); + if (nread == 0) { + slot->cache_tokens.clear(); // KV may already been invalidated? + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); + break; + } + tokens.resize(token_count); + slot->cache_tokens.clear(); + slot->cache_tokens.insert(tokens); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = false; + res->n_tokens = token_count; + res->n_bytes = nread; + res->t_ms = t_restore_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: + { + if (!ensure_no_mtmd(task.id)) break; + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + // Erase token cache + const size_t n_erased = slot->cache_tokens.size(); + llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1); + slot->cache_tokens.clear(); + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->n_erased = n_erased; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SET_LORA: + { + params_base.lora_adapters = std::move(task.set_lora); + auto res = std::make_unique(); + res->id = task.id; + queue_results.send(std::move(res)); + } break; + + } + } + + void update_slots() { + // check if all slots are idle + { + bool all_idle = true; + + for (auto & slot : slots) { + if (slot.is_processing()) { + all_idle = false; + break; + } + } + + if (all_idle) { + SRV_INF("%s", "all slots are idle\n"); + if (clean_kv_cache) { + kv_cache_clear(); + } + + return; + } + } + + { + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); + queue_tasks.post(std::move(task)); + } + + // apply context-shift if needed + // TODO: simplify and improve + for (server_slot & slot : slots) { + if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) { + if (!params_base.ctx_shift) { + // this check is redundant (for good) + // we should never get here, because generation should already stopped in process_token() + slot.release(); + send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + continue; + } + + if (mctx) { + // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded + // we don't support ctx_shift because an image chunk may contains multiple tokens + GGML_ABORT("not supported by multimodal"); + } + + // Shift context + const int n_keep = slot.params.n_keep + add_bos_token; + const int n_left = slot.n_past - n_keep; + const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); + + SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); + + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.n_past, -n_discard); + + // add generated tokens to cache + { + llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy + for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { + new_tokens[i - n_discard] = new_tokens[i]; + } + + new_tokens.resize(slot.cache_tokens.size() - n_discard); + slot.cache_tokens.clear(); + slot.cache_tokens.insert(new_tokens); + } + + slot.n_past -= n_discard; + + slot.truncated = true; + } + } + + // start populating the batch for this iteration + common_batch_clear(batch); + + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end(); + }; + + // frist, add sampled tokens from any ongoing sequences + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + + slot.i_batch = batch.n_tokens; + + common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); + + slot.n_past += 1; + slot.cache_tokens.push_back(slot.sampled); + + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", + slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); + } + + // process in chunks of params.n_batch + int32_t n_batch = llama_n_batch(ctx); + int32_t n_ubatch = llama_n_ubatch(ctx); + + // next, batch any pending prompts without exceeding n_batch + if (params_base.cont_batching || batch.n_tokens == 0) { + for (auto & slot : slots) { + // check if we can batch this slot with the previous one + if (slot.is_processing()) { + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + } + + // this slot still has a prompt to be processed + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { + auto & prompt_tokens = slot.prompt_tokens; + + // TODO: maybe move branch to outside of this loop in the future + if (slot.state == SLOT_STATE_STARTED) { + slot.t_start_process_prompt = ggml_time_us(); + slot.t_start_generation = 0; + + slot.n_past = 0; + slot.n_prompt_tokens = prompt_tokens.size(); + slot.state = SLOT_STATE_PROCESSING_PROMPT; + + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); + + // print prompt tokens (for debugging) + /*if (1) { + // first 16 tokens (avoid flooding logs) + for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + } + } else { + // all + for (int i = 0; i < (int) prompt_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + } + }*/ + + // empty prompt passed -> release the slot and send empty response + if (prompt_tokens.empty()) { + SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); + + slot.release(); + slot.print_timings(); + send_final_response(slot); + continue; + } + + // TODO: support memory-less logits computation + if (slot.need_logits() && !llama_get_memory(ctx)) { + slot.release(); + send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); + continue; + } + + if (!slot.can_split()) { + if (slot.n_prompt_tokens > n_ubatch) { + slot.release(); + send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); + continue; + } + + if (slot.n_prompt_tokens > slot.n_ctx) { + slot.release(); + send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER); + continue; + } + } else { + if (!params_base.ctx_shift) { + // if context shift is disabled, we make sure prompt size is smaller than KV size + // TODO: there should be a separate parameter that control prompt truncation + // context shift should be applied only during the generation phase + if (slot.n_prompt_tokens >= slot.n_ctx) { + slot.release(); + send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST); + continue; + } + } + if (slot.params.n_keep < 0) { + slot.params.n_keep = slot.n_prompt_tokens; + } + slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); + + // if input prompt is too big, truncate it + if (slot.n_prompt_tokens >= slot.n_ctx) { + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + const int n_left = slot.n_ctx - slot.params.n_keep; + + const int n_block_size = n_left / 2; + const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; + + const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens(); + llama_tokens new_tokens( + curr_tokens.begin(), + curr_tokens.begin() + slot.params.n_keep); + + new_tokens.insert( + new_tokens.end(), + curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, + curr_tokens.end()); + + prompt_tokens.clear(); + prompt_tokens.insert(new_tokens); + + slot.truncated = true; + slot.n_prompt_tokens = prompt_tokens.size(); + + SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); + + GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); + } + + if (slot.params.cache_prompt) { + // reuse any previously computed tokens that are common with the new prompt + slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); + + // reuse chunks from the cached prompt by shifting their KV cache in the new position + if (params_base.n_cache_reuse > 0) { + size_t head_c = slot.n_past; // cache + size_t head_p = slot.n_past; // current prompt + + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + + SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); + + while (head_c < slot.cache_tokens.size() && + head_p < prompt_tokens.size()) { + + size_t n_match = 0; + while (head_c + n_match < slot.cache_tokens.size() && + head_p + n_match < prompt_tokens.size() && + slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + + n_match++; + } + + if (n_match >= (size_t) params_base.n_cache_reuse) { + SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); + //for (size_t i = head_p; i < head_p + n_match; i++) { + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + //} + + const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; + + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); + + for (size_t i = 0; i < n_match; i++) { + slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]); + slot.n_past++; + } + + head_c += n_match; + head_p += n_match; + } else { + head_c += 1; + } + } + + SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); + } + } else { + // if we don't cache the prompt, we have to remove the entire KV cache + slot.n_past = 0; + } + + if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + if (pos_min == -1) { + SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); + GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); + } + + const auto n_swa = llama_model_n_swa(model); + if (pos_min > std::max(0, slot.n_past - n_swa)) { + SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); + SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", + "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + slot.n_past = 0; + } + } + } + + if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { + SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens); + + slot.n_past--; + } + + slot.n_prompt_tokens_processed = 0; + } + + if (!slot.can_split()) { + // cannot fit the prompt in the current batch - will try next iter + if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + continue; + } + } + + // keep only the common part + if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) { + // could not partially delete (likely using a non-Transformer model) + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); + + // there is no common part left + slot.n_past = 0; + } + + SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); + + // remove the non-common part from the cache + slot.cache_tokens.keep_first(slot.n_past); + + // check if we should process the image + if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { + // process the image + int32_t new_n_past; + int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); + int32_t n_pos = new_n_past - slot.n_past; + + if (res != 0) { + SLT_ERR(slot, "failed to process image, res = %d\n", res); + slot.release(); + send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + continue; + } + + // add the image chunk to cache + { + const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past); + slot.cache_tokens.push_back(chunk.get()); // copy + } + + slot.n_past += n_pos; + slot.n_prompt_tokens_processed += n_pos; + } + + // add prompt tokens for processing in the current batch + while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + // get next token to process + llama_token cur_tok = slot.prompt_tokens[slot.n_past]; + if (cur_tok == LLAMA_TOKEN_NULL) { + break; // end of text chunk + } + + // embedding requires all tokens in the batch to be output + const bool need_embd = server_task_type_need_embd(slot.task_type); + + common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); + slot.cache_tokens.push_back(cur_tok); + + slot.n_prompt_tokens_processed++; + slot.n_past++; + } + + // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); + + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + + // entire prompt has been processed + if (slot.n_past == slot.n_prompt_tokens) { + slot.state = SLOT_STATE_DONE_PROMPT; + + GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size()); + + common_sampler_reset(slot.smpl); + + // Process all prompt tokens through sampler system + for (int i = 0; i < slot.n_prompt_tokens; ++i) { + llama_token id = slot.prompt_tokens[i]; + if (id != LLAMA_TOKEN_NULL) { + common_sampler_accept(slot.smpl, id, false); + } + } + + // extract the logits only for the last token + batch.logits[batch.n_tokens - 1] = true; + + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; + + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); + } + } + + if (batch.n_tokens >= n_batch) { + break; + } + } + } + + if (batch.n_tokens == 0) { + SRV_WRN("%s", "no tokens to decode\n"); + return; + } + + SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + + if (slot_batched) { + // apply lora, only need to do it once per batch + common_set_adapter_lora(ctx, slot_batched->lora); + + llama_set_embeddings(ctx, slot_batched->need_embd()); + } + + int32_t i_next = 0; + + // process the created batch of tokens + for (int32_t i = 0; i < batch.n_tokens; i = i_next) { + const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); + + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + }; + + const int ret = llama_decode(ctx, batch_view); + + metrics.on_decoded(slots); + + if (ret != 0) { + { + std::string err; + + if (n_batch == 1 && ret == 1) { + err = "Context size has been exceeded."; + } + + if (ret == -1) { + err = "Invalid input batch."; + } + + if (ret < -1) { + // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() + err = "Compute error."; + } + + // TODO: handle ret == 2 (abort) when we start aborting + + if (!err.empty()) { + SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); + for (auto & slot : slots) { + slot.release(); + send_error(slot, err); + } + break; + } + } + + // retry with half the batch size to try to find a free slot in the KV cache + n_batch /= 2; + + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); + + continue; // continue loop of n_batch + } + + // move the head of the batch forward with the number of tokens we just processed + i_next = i + n_tokens; + + // on successful decode, restore the original batch size + n_batch = llama_n_batch(ctx); + + for (auto & slot : slots) { + if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { + continue; // continue loop of slots + } + + if (slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { + // prompt evaluated for embedding + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + if (slot.task_type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + // prompt evaluated for next-token prediction + slot.state = SLOT_STATE_GENERATING; + } else if (slot.state != SLOT_STATE_GENERATING) { + continue; // continue loop of slots + } + + const int tok_idx = slot.i_batch - i; + + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + + slot.i_batch = -1; + + common_sampler_accept(slot.smpl, id, true); + + slot.n_decoded += 1; + + const int64_t t_current = ggml_time_us(); + + if (slot.n_decoded == 1) { + slot.t_start_generation = t_current; + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } + + slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3; + + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs + + if (slot.params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx); + } + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + continue; + } + } + + // do speculative decoding + for (auto & slot : slots) { + if (!slot.is_processing() || !slot.can_speculate()) { + continue; + } + + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + if (mctx) { + // we should never reach this, as speculative is automatically disabled if mmproj is loaded + GGML_ABORT("not supported by multimodal"); + } + + // determine the max draft that fits the current slot state + int n_draft_max = slot.params.speculative.n_max; + + // note: n_past is not yet increased for the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); + + if (slot.n_remaining > 0) { + n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); + } + + SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); + + if (n_draft_max < slot.params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min); + + continue; + } + + llama_token id = slot.sampled; + + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; + params_spec.p_min = slot.params.speculative.p_min; + + const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + + // ignore small drafts + if (slot.params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); + + continue; + } + + // keep track of total number of drafted tokens tested + slot.n_draft_total += draft.size(); + + // construct the speculation batch + common_batch_clear(slot.batch_spec); + common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true); + + for (size_t i = 0; i < draft.size(); ++i) { + common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); + } + + SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + + slot.n_past += ids.size(); + slot.n_decoded += ids.size(); + + // update how many tokens out of those tested were accepted + slot.n_draft_accepted += ids.size() - 1; + + slot.cache_tokens.push_back(id); + slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); + + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // set later + + // TODO: set result.probs + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + break; + } + } + + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past); + } + } + + SRV_DBG("%s", "run slots completed\n"); + } + + json model_meta() const { + return json { + {"vocab_type", llama_vocab_type (vocab)}, + {"n_vocab", llama_vocab_n_tokens (vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd (model)}, + {"n_params", llama_model_n_params (model)}, + {"size", llama_model_size (model)}, + }; + } +}; diff --git a/tools/server-core/server-utils.hpp b/tools/server-core/server-utils.hpp new file mode 100644 index 0000000000000..e44a6aae16d5c --- /dev/null +++ b/tools/server-core/server-utils.hpp @@ -0,0 +1,1321 @@ +#pragma once + +#include "common.h" +#include "log.h" +#include "llama.h" +#include "arg.h" // common_remote_get_content +#include "base64.hpp" +#include "mtmd.h" +#include "mtmd-helper.h" +#include "chat.h" + +#define JSON_ASSERT GGML_ASSERT +#include + +#include +#include +#include +#include +#include +#include + +#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" + +using json = nlohmann::ordered_json; + +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) + +#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +using raw_buffer = std::vector; + +template +static T json_value(const json & body, const std::string & key, const T & default_value) { + // Fallback null to default value + if (body.contains(key) && !body.at(key).is_null()) { + try { + return body.at(key); + } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) { + LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), json(default_value).type_name()); + return default_value; + } + } else { + return default_value; + } +} + +const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); + +// thin wrapper around common_grammar_trigger with (de)serialization functions +struct server_grammar_trigger { + common_grammar_trigger value; + + server_grammar_trigger() = default; + server_grammar_trigger(const common_grammar_trigger & value) : value(value) {} + server_grammar_trigger(const json & in) { + value.type = (common_grammar_trigger_type) in.at("type").get(); + value.value = in.at("value").get(); + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + value.token = (llama_token) in.at("token").get(); + } + } + + json to_json() const { + json out { + {"type", (int) value.type}, + {"value", value.value}, + }; + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + out["token"] = (int) value.token; + } + return out; + } +}; + +// +// tokenizer and input processing utils +// + +static bool json_is_array_of_numbers(const json & data) { + if (data.is_array()) { + for (const auto & e : data) { + if (!e.is_number_integer()) { + return false; + } + } + return true; + } + return false; +} + +// is array having BOTH numbers & strings? +static bool json_is_array_of_mixed_numbers_strings(const json & data) { + bool seen_string = false; + bool seen_number = false; + if (data.is_array()) { + for (const auto & e : data) { + seen_string |= e.is_string(); + seen_number |= e.is_number_integer(); + if (seen_number && seen_string) { + return true; + } + } + } + return false; +} + +// get value by path(key1 / key2) +static json json_get_nested_values(const std::vector & paths, const json & js) { + json result = json::object(); + + for (const std::string & path : paths) { + json current = js; + const auto keys = string_split(path, /*separator*/ '/'); + bool valid_path = true; + for (const std::string & k : keys) { + if (valid_path && current.is_object() && current.contains(k)) { + current = current[k]; + } else { + valid_path = false; + } + } + if (valid_path) { + result[path] = current; + } + } + return result; +} + +/** + * this handles 2 cases: + * - only string, example: "string" + * - mixed string and tokens, example: [12, 34, "string", 56, 78] + */ +static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + llama_tokens prompt_tokens; + + if (json_prompt.is_array()) { + bool first = true; + for (const auto & p : json_prompt) { + if (p.is_string()) { + auto s = p.template get(); + + llama_tokens p; + if (first) { + p = common_tokenize(vocab, s, add_special, parse_special); + first = false; + } else { + p = common_tokenize(vocab, s, false, parse_special); + } + + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); + } else { + if (first) { + first = false; + } + + prompt_tokens.push_back(p.template get()); + } + } + } else { + auto s = json_prompt.template get(); + prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); + } + + return prompt_tokens; +} + +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * and multiple prompts (multi-tasks): + * - "prompt": ["string1", "string2"] + * - "prompt": ["string1", [12, 34, 56]] + * - "prompt": [[12, 34, 56], [78, 90, 12]] + * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]] + */ +static std::vector tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { + std::vector result; + if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { + // string or mixed + result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special)); + } else if (json_is_array_of_numbers(json_prompt)) { + // array of tokens + result.push_back(json_prompt.get()); + } else if (json_prompt.is_array()) { + // array of prompts + result.reserve(json_prompt.size()); + for (const auto & p : json_prompt) { + if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) { + result.push_back(tokenize_mixed(vocab, p, add_special, parse_special)); + } else if (json_is_array_of_numbers(p)) { + // array of tokens + result.push_back(p.get()); + } else { + throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens"); + } + } + } else { + throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts"); + } + if (result.empty()) { + throw std::runtime_error("\"prompt\" must not be empty"); + } + return result; +} + +// return the last index of character that can form a valid string +// if the last character is potentially cut in half, return the index before the cut +// if validate_utf8(text) == text.size(), then the whole text is valid utf8 +static size_t validate_utf8(const std::string& text) { + size_t len = text.size(); + if (len == 0) return 0; + + // Check the last few bytes to see if a multi-byte character is cut off + for (size_t i = 1; i <= 4 && i <= len; ++i) { + unsigned char c = text[len - i]; + // Check for start of a multi-byte sequence from the end + if ((c & 0xE0) == 0xC0) { + // 2-byte character start: 110xxxxx + // Needs at least 2 bytes + if (i < 2) return len - i; + } else if ((c & 0xF0) == 0xE0) { + // 3-byte character start: 1110xxxx + // Needs at least 3 bytes + if (i < 3) return len - i; + } else if ((c & 0xF8) == 0xF0) { + // 4-byte character start: 11110xxx + // Needs at least 4 bytes + if (i < 4) return len - i; + } + } + + // If no cut-off multi-byte character is found, return full length + return len; +} + +// +// template utils +// + +// format rerank task: [BOS]query[EOS][SEP]doc[EOS] +static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) { + llama_tokens result; + + // Get EOS token - use SEP token as fallback if EOS is not available + llama_token eos_token = llama_vocab_eos(vocab); + if (eos_token == LLAMA_TOKEN_NULL) { + eos_token = llama_vocab_sep(vocab); + } + + result.reserve(doc.size() + query.size() + 4); + if (llama_vocab_get_add_bos(vocab)) { + result.push_back(llama_vocab_bos(vocab)); + } + result.insert(result.end(), query.begin(), query.end()); + if (llama_vocab_get_add_eos(vocab)) { + result.push_back(eos_token); + } + if (llama_vocab_get_add_sep(vocab)) { + result.push_back(llama_vocab_sep(vocab)); + } + result.insert(result.end(), doc.begin(), doc.end()); + if (llama_vocab_get_add_eos(vocab)) { + result.push_back(eos_token); + } + + return result; +} + +// format infill task +static llama_tokens format_infill( + const llama_vocab * vocab, + const json & input_prefix, + const json & input_suffix, + const json & input_extra, + const int n_batch, + const int n_predict, + const int n_ctx, + const bool spm_infill, + const llama_tokens & tokens_prompt + ) { + // TODO: optimize this block by reducing memory allocations and movement + + // use FIM repo-level pattern: + // ref: https://arxiv.org/pdf/2409.12186 + // + // [FIM_REP]myproject + // [FIM_SEP]filename0 + // extra chunk 0 + // [FIM_SEP]filename1 + // extra chunk 1 + // ... + // [FIM_SEP]filename + // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt + // + llama_tokens extra_tokens; + extra_tokens.reserve(n_ctx); + + auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); + auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); + + if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: make project name an input + static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); + + extra_tokens.push_back(llama_vocab_fim_rep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); + } + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + const std::string text = json_value(chunk, "text", std::string()); + const std::string filename = json_value(chunk, "filename", std::string("tmp")); + + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } else { + // chunk separator in binary form to avoid confusing the AI + static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; + static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); + + extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); + } + + const auto chunk_tokens = common_tokenize(vocab, text, false, false); + extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); + } + + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: current filename + static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } + + // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) + const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); + const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); + + SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); + + // fill the rest of the context with extra chunks + const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); + + tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); + tokens_suffix.resize(n_suffix_take); + + tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); + tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); + tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); + + auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; + auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; + + if (llama_vocab_get_add_bos(vocab)) { + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); + } + + SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); + + // put the extra context before the FIM prefix + embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); + + embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); + embd_inp.push_back(llama_vocab_fim_mid(vocab)); + + return embd_inp; +} + +// +// base64 utils (TODO: move to common in the future) +// + +static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + +static inline bool is_base64(uint8_t c) { + return (isalnum(c) || (c == '+') || (c == '/')); +} + +static inline raw_buffer base64_decode(const std::string & encoded_string) { + int i = 0; + int j = 0; + int in_ = 0; + + int in_len = encoded_string.size(); + + uint8_t char_array_4[4]; + uint8_t char_array_3[3]; + + raw_buffer ret; + + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i == 4) { + for (i = 0; i < 4; i++) { + char_array_4[i] = base64_chars.find(char_array_4[i]); + } + + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) { + ret.push_back(char_array_3[i]); + } + + i = 0; + } + } + + if (i) { + for (j = i; j < 4; j++) { + char_array_4[j] = 0; + } + + for (j = 0; j < 4; j++) { + char_array_4[j] = base64_chars.find(char_array_4[j]); + } + + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; j < i - 1; j++) { + ret.push_back(char_array_3[j]); + } + } + + return ret; +} + +// +// random string / id +// + +static std::string random_string() { + static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); + + std::random_device rd; + std::mt19937 generator(rd()); + + std::string result(32, ' '); + + for (int i = 0; i < 32; ++i) { + result[i] = str[generator() % str.size()]; + } + + return result; +} + +static std::string gen_chatcmplid() { + return "chatcmpl-" + random_string(); +} + +static std::string gen_tool_call_id() { + return random_string(); +} + +// +// other common utils +// + +// TODO: reuse llama_detokenize +template +static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { + std::string ret; + for (; begin != end; ++begin) { + ret += common_token_to_piece(ctx, *begin); + } + + return ret; +} + +// format incomplete utf-8 multibyte character for output +static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { + std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); + + // if the size is 1 and first bit is 1, meaning it's a partial character + // (size > 1 meaning it's already a known token) + if (out.size() == 1 && (out[0] & 0x80) == 0x80) { + std::stringstream ss; + ss << std::hex << (out[0] & 0xff); + std::string res(ss.str()); + out = "byte: \\x" + res; + } + + return out; +} + +// +// OAI utils +// + +// used by /completions endpoint +static json oaicompat_completion_params_parse(const json & body) { + json llama_params; + + if (!body.contains("prompt")) { + throw std::runtime_error("\"prompt\" is required"); + } + + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { + llama_params["stop"] = json::array({body.at("stop").get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + // Handle "n" field + int n_choices = json_value(body, "n", 1); + if (n_choices != 1) { + throw std::runtime_error("Only one completion choice is allowed"); + } + + // Handle "echo" field + if (json_value(body, "echo", false)) { + throw std::runtime_error("Only no echo is supported"); + } + + // Params supported by OAI but unsupported by llama.cpp + static const std::vector unsupported_params { "best_of", "suffix" }; + for (const auto & param : unsupported_params) { + if (body.contains(param)) { + throw std::runtime_error("Unsupported param: " + param); + } + } + + // Copy remaining properties to llama_params + for (const auto & item : body.items()) { + // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { + llama_params[item.key()] = item.value(); + } + } + + return llama_params; +} + +struct oaicompat_parser_options { + bool use_jinja; + bool prefill_assistant; + common_reasoning_format reasoning_format; + std::map chat_template_kwargs; + common_chat_templates * tmpls; + bool allow_image; + bool allow_audio; + bool enable_thinking = true; +}; + +// used by /chat/completions endpoint +static json oaicompat_chat_params_parse( + json & body, /* openai api json semantics */ + const oaicompat_parser_options & opt, + std::vector & out_files) +{ + json llama_params; + + auto tools = json_value(body, "tools", json()); + auto has_tools = tools.is_array() && !tools.empty(); + auto stream = json_value(body, "stream", false); + auto tool_choice = json_value(body, "tool_choice", std::string("auto")); + + if (!opt.use_jinja) { + if (has_tools) { + throw std::runtime_error("tools param requires --jinja flag"); + } + if (tool_choice != "auto") { + throw std::runtime_error("tool_choice param requires --jinja flag"); + } + } + + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { + llama_params["stop"] = json::array({body.at("stop").get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + auto json_schema = json_value(body, "json_schema", json()); + auto grammar = json_value(body, "grammar", std::string()); + if (!json_schema.is_null() && !grammar.empty()) { + throw std::runtime_error("Cannot use both json_schema and grammar"); + } + + // Handle "response_format" field + if (body.contains("response_format")) { + json response_format = json_value(body, "response_format", json::object()); + std::string response_type = json_value(response_format, "type", std::string()); + if (response_type == "json_object") { + json_schema = json_value(response_format, "schema", json::object()); + } else if (response_type == "json_schema") { + auto schema_wrapper = json_value(response_format, "json_schema", json::object()); + json_schema = json_value(schema_wrapper, "schema", json::object()); + } else if (!response_type.empty() && response_type != "text") { + throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); + } + } + + // get input files + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + json & messages = body.at("messages"); + if (!messages.is_array()) { + throw std::runtime_error("Expected 'messages' to be an array"); + } + for (auto & msg : messages) { + std::string role = json_value(msg, "role", std::string()); + if (role != "assistant" && !msg.contains("content")) { + throw std::runtime_error("All non-assistant messages must contain 'content'"); + } + if (role == "assistant") { + if (!msg.contains("content") && !msg.contains("tool_calls")) { + throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!"); + } + if (!msg.contains("content")) { + continue; // avoid errors with no content + } + } + json & content = msg.at("content"); + if (content.is_string() || content.is_null()) { + continue; + } + + if (!content.is_array()) { + throw std::runtime_error("Expected 'content' to be a string or an array"); + } + + for (auto & p : content) { + std::string type = json_value(p, "type", std::string()); + if (type == "image_url") { + if (!opt.allow_image) { + throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); + } + + json image_url = json_value(p, "image_url", json::object()); + std::string url = json_value(image_url, "url", std::string()); + if (string_starts_with(url, "http")) { + // download remote image + // TODO @ngxson : maybe make these params configurable + common_remote_params params; + params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.max_size = 1024 * 1024 * 10; // 10MB + params.timeout = 10; // seconds + SRV_INF("downloading image from '%s'\n", url.c_str()); + auto res = common_remote_get_content(url, params); + if (200 <= res.first && res.first < 300) { + SRV_INF("downloaded %ld bytes\n", res.second.size()); + raw_buffer data; + data.insert(data.end(), res.second.begin(), res.second.end()); + out_files.push_back(data); + } else { + throw std::runtime_error("Failed to download image"); + } + + } else { + // try to decode base64 image + std::vector parts = string_split(url, /*separator*/ ','); + if (parts.size() != 2) { + throw std::runtime_error("Invalid image_url.url value"); + } else if (!string_starts_with(parts[0], "data:image/")) { + throw std::runtime_error("Invalid image_url.url format: " + parts[0]); + } else if (!string_ends_with(parts[0], "base64")) { + throw std::runtime_error("image_url.url must be base64 encoded"); + } else { + auto base64_data = parts[1]; + auto decoded_data = base64_decode(base64_data); + out_files.push_back(decoded_data); + } + } + + // replace this chunk with a marker + p["type"] = "text"; + p["text"] = mtmd_default_marker(); + p.erase("image_url"); + + } else if (type == "input_audio") { + if (!opt.allow_audio) { + throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); + } + + json input_audio = json_value(p, "input_audio", json::object()); + std::string data = json_value(input_audio, "data", std::string()); + std::string format = json_value(input_audio, "format", std::string()); + // while we also support flac, we don't allow it here so we matches the OAI spec + if (format != "wav" && format != "mp3") { + throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'"); + } + auto decoded_data = base64_decode(data); // expected to be base64 encoded + out_files.push_back(decoded_data); + + // replace this chunk with a marker + p["type"] = "text"; + p["text"] = mtmd_default_marker(); + p.erase("input_audio"); + + } else if (type != "text") { + throw std::runtime_error("unsupported content[].type"); + } + } + } + + common_chat_templates_inputs inputs; + inputs.messages = common_chat_msgs_parse_oaicompat(messages); + inputs.tools = common_chat_tools_parse_oaicompat(tools); + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice); + inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); + inputs.grammar = grammar; + inputs.use_jinja = opt.use_jinja; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); + inputs.reasoning_format = opt.reasoning_format; + inputs.enable_thinking = opt.enable_thinking; + if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { + if (body.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + llama_params["parse_tool_calls"] = true; + } + + // merge the template args provided from command line with the args provided in the user request + auto chat_template_kwargs_object = json_value(body, "chat_template_kwargs", json::object()); + inputs.chat_template_kwargs = opt.chat_template_kwargs; + for (const auto & item : chat_template_kwargs_object.items()) { + inputs.chat_template_kwargs[item.key()] = item.value().dump(); + } + + // if the assistant message appears at the end of list, we do not add end-of-turn token + // for ex. this can be useful to modify the reasoning process in reasoning models + bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant; + common_chat_msg last_message; + if (prefill_assistant_message) { + last_message = inputs.messages.back(); + inputs.messages.pop_back(); + + /* sanity check, max one assistant message at the end of the list */ + if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){ + throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list."); + } + + /* TODO: test this properly */ + inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE; + + if ( (!inputs.enable_thinking) || inputs.chat_template_kwargs.find("enable_thinking") != inputs.chat_template_kwargs.end()) { + throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking."); + } + + inputs.add_generation_prompt = true; + } + + // Apply chat template to the list of messages + auto chat_params = common_chat_templates_apply(opt.tmpls, inputs); + + /* Append assistant prefilled message */ + if (prefill_assistant_message) { + chat_params.prompt += last_message.content; + } + + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; + if (!chat_params.grammar.empty()) { + llama_params["grammar"] = chat_params.grammar; + } + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + auto grammar_triggers = json::array(); + for (const auto & trigger : chat_params.grammar_triggers) { + server_grammar_trigger ct(trigger); + grammar_triggers.push_back(ct.to_json()); + } + llama_params["grammar_triggers"] = grammar_triggers; + llama_params["preserved_tokens"] = chat_params.preserved_tokens; + llama_params["thinking_forced_open"] = chat_params.thinking_forced_open; + for (const auto & stop : chat_params.additional_stops) { + llama_params["stop"].push_back(stop); + } + + // Handle "n" field + int n_choices = json_value(body, "n", 1); + if (n_choices != 1) { + throw std::runtime_error("Only one completion choice is allowed"); + } + + // Handle "logprobs" field + // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future + if (json_value(body, "logprobs", false)) { + if (has_tools && stream) { + throw std::runtime_error("logprobs is not supported with tools + stream"); + } + llama_params["n_probs"] = json_value(body, "top_logprobs", 20); + } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { + throw std::runtime_error("top_logprobs requires logprobs to be set to true"); + } + + // Copy remaining properties to llama_params + // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint. + // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp + for (const auto & item : body.items()) { + // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { + llama_params[item.key()] = item.value(); + } + } + + return llama_params; +} + +static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) { + json data = json::array(); + int32_t n_tokens = 0; + int i = 0; + for (const auto & elem : embeddings) { + json embedding_obj; + + if (use_base64) { + const auto& vec = json_value(elem, "embedding", json::array()).get>(); + const char* data_ptr = reinterpret_cast(vec.data()); + size_t data_size = vec.size() * sizeof(float); + embedding_obj = { + {"embedding", base64::encode(data_ptr, data_size)}, + {"index", i++}, + {"object", "embedding"}, + {"encoding_format", "base64"} + }; + } else { + embedding_obj = { + {"embedding", json_value(elem, "embedding", json::array())}, + {"index", i++}, + {"object", "embedding"} + }; + } + data.push_back(embedding_obj); + + n_tokens += json_value(elem, "tokens_evaluated", 0); + } + + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"data", data} + }; + + return res; +} + +static json format_response_rerank( + const json & request, + const json & ranks, + bool is_tei_format, + std::vector & texts) { + json res; + if (is_tei_format) { + // TEI response format + res = json::array(); + bool return_text = json_value(request, "return_text", false); + for (const auto & rank : ranks) { + int index = json_value(rank, "index", 0); + json elem = json{ + {"index", index}, + {"score", json_value(rank, "score", 0.0)}, + }; + if (return_text) { + elem["text"] = std::move(texts[index]); + } + res.push_back(elem); + } + } else { + // Jina response format + json results = json::array(); + int32_t n_tokens = 0; + for (const auto & rank : ranks) { + results.push_back(json{ + {"index", json_value(rank, "index", 0)}, + {"relevance_score", json_value(rank, "score", 0.0)}, + }); + + n_tokens += json_value(rank, "tokens_evaluated", 0); + } + + res = json{ + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json{ + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"results", results} + }; + } + + return res; +} + +static bool is_valid_utf8(const std::string & str) { + const unsigned char* bytes = reinterpret_cast(str.data()); + const unsigned char* end = bytes + str.length(); + + while (bytes < end) { + if (*bytes <= 0x7F) { + // 1-byte sequence (0xxxxxxx) + bytes++; + } else if ((*bytes & 0xE0) == 0xC0) { + // 2-byte sequence (110xxxxx 10xxxxxx) + if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) + return false; + bytes += 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) + return false; + bytes += 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || + (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) + return false; + bytes += 4; + } else { + // Invalid UTF-8 lead byte + return false; + } + } + + return true; +} + +static json format_tokenizer_response(const json & tokens) { + return json { + {"tokens", tokens} + }; +} + +static json format_detokenized_response(const std::string & content) { + return json { + {"content", content} + }; +} + +static json format_logit_bias(const std::vector & logit_bias) { + json data = json::array(); + for (const auto & lb : logit_bias) { + data.push_back(json{ + {"bias", lb.bias}, + {"token", lb.token}, + }); + } + return data; +} + +static std::string safe_json_to_str(const json & data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); +} + +static std::vector get_token_probabilities(llama_context * ctx, int idx) { + std::vector cur; + const auto * logits = llama_get_logits_ith(ctx, idx); + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); + + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + // sort tokens by logits + std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + + // apply softmax + float max_l = cur[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < cur.size(); ++i) { + float p = expf(cur[i].logit - max_l); + cur[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < cur.size(); ++i) { + cur[i].p /= cum_sum; + } + + return cur; +} + +static bool are_lora_equal( + const std::vector & l1, + const std::vector & l2) { + if (l1.size() != l2.size()) { + return false; + } + for (size_t i = 0; i < l1.size(); ++i) { + // we don't check lora.path to reduce the time complexity + if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { + return false; + } + } + return true; +} + +// parse lora config from JSON request, returned a copy of lora_base with updated scale +static std::vector parse_lora_request( + const std::vector & lora_base, + const json & data) { + std::vector lora(lora_base); + int max_idx = lora.size(); + + // clear existing value + for (auto & entry : lora) { + entry.scale = 0.0f; + } + + // set value + for (const auto & entry : data) { + int id = json_value(entry, "id", -1); + float scale = json_value(entry, "scale", 0.0f); + if (0 <= id && id < max_idx) { + lora[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); + } + } + + return lora; +} + +// +// utils for interacting with libmtmd +// (may need to refactor in near future) +// + +/** + * server_tokens is a helper to manage the input tokens and image for the server. + * it is made this way to simplify the logic of KV cache management. + */ +struct server_tokens { + bool has_mtmd = false; + +private: // disallow accessing these members directly, risking out-of-sync + + // map a **start** position in tokens to the image chunk + std::unordered_map map_pos_to_media; + + // list of tokens + // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token + // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position** + // important: for models using mrope, an image can contain multiple tokens but will use only one **position** + llama_tokens tokens; + + // for ex. with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // pos 0 1 2 3 4 5 6 7 8 9 + // map_pos_to_media will contain: {5, img0}, {8, img1} + +public: + server_tokens() = default; + ~server_tokens() = default; + + // Prevent copying + server_tokens(const server_tokens&) = delete; + server_tokens& operator=(const server_tokens&) = delete; + + // Allow moving (usually implicitly generated if members are movable) + server_tokens(server_tokens&&) = default; + server_tokens& operator=(server_tokens&&) = default; + + // Allow accessing elements using [] operator + llama_token operator[](size_t index) { return tokens[index]; } + const llama_token& operator[](size_t index) const { return tokens[index]; } + + server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { + for (size_t i = 0; i < mtmd_chunks.size(); ++i) { + push_back(mtmd_chunks[i]); + } + } + + server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} + + // for debugging + std::string str() const { + std::ostringstream oss; + oss << "tokens: "; + for (const auto & t : tokens) { + if (t == LLAMA_TOKEN_NULL) { + oss << " "; + } else { + oss << t << " "; + } + } + oss << "\n"; + oss << "image pos: "; + for (const auto & it : map_pos_to_media) { + oss << it.first << ", "; + } + return oss.str(); + } + + const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const { + auto it = map_pos_to_media.find(pos); + if (it != map_pos_to_media.end()) { + return it->second; + } else { + throw std::runtime_error("Chunk not found"); + } + } + + void push_back(llama_token tok) { + if (tok == LLAMA_TOKEN_NULL) { + throw std::runtime_error("Invalid token"); + } + tokens.emplace_back(tok); + } + + // will create a copy of the chunk if it contains non-text data + void push_back(const mtmd_input_chunk * chunk) { + auto type = mtmd_input_chunk_get_type(chunk); + if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { + GGML_ASSERT(has_mtmd); + const int n_pos = mtmd_input_chunk_get_n_pos(chunk); + llama_pos start_pos = tokens.size(); + for (int i = 0; i < n_pos; ++i) { + tokens.emplace_back(LLAMA_TOKEN_NULL); + } + mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); + map_pos_to_media[start_pos] = std::move(new_chunk); + } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + for (size_t i = 0; i < n_tokens; ++i) { + push_back(text_tokens[i]); + } + } else { + GGML_ABORT("Invalid chunk type"); + } + } + + // for compatibility with context shift and prompt truncation + void insert(const llama_tokens & inp_tokens) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); + } + + // for compatibility with speculative decoding, ctx shift, slot save/load + const llama_tokens & get_text_tokens() const { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + return tokens; + } + + // for compatibility with speculative decoding + void set_token(llama_pos pos, llama_token id) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens[pos] = id; + } + + size_t size() const { + return tokens.size(); + } + + bool empty() const { + return tokens.empty(); + } + + void clear() { + tokens.clear(); + } + + void keep_first(size_t n) { + GGML_ASSERT(n <= tokens.size()); + if (has_mtmd) { + if (n == tokens.size()) { + return; // nothing to do + } + // we throw an error if we try to remove a token in the middle of an image + // for ex. with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // n 1 2 3 4 5 6 7 8 9 10 + // allowed to resize ^ ^ + // disallowed to resize ^ ^ ^ + if (n > 0) { + llama_token last_token = tokens[n - 1]; + // make sure we never remove tokens in the middle of an image + if (last_token == LLAMA_TOKEN_NULL) { + find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk + } + } + // remove all image chunks that are not used anymore + for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) { + llama_pos pos = it->first; + if (pos >= (llama_pos)n) { + it = map_pos_to_media.erase(it); + } else { + ++it; + } + } + } + tokens.resize(n); + } + + std::string detokenize(const llama_context * ctx, bool special) const { + llama_tokens text_tokens; + text_tokens.reserve(tokens.size()); + for (const auto & t : tokens) { + if (t != LLAMA_TOKEN_NULL) { + text_tokens.push_back(t); + } + } + return common_detokenize(ctx, text_tokens, special); + } + + size_t get_common_prefix(const server_tokens & b) const { + size_t max_idx = std::min(tokens.size(), b.tokens.size()); + for (size_t i = 0; i < max_idx; ++i) { + auto & ai = tokens[i]; + auto & bi = b.tokens[i]; + + if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { + GGML_ASSERT(has_mtmd); + const auto & a_chunk = find_chunk(i); + const auto & b_chunk = b.find_chunk(i); + GGML_ASSERT(a_chunk && b_chunk); + std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get()); + std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get()); + size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get()); + size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get()); + if (ai_id == bi_id && a_pos == b_pos) { + GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen + i += a_pos - 1; // will be +1 by the for loop + continue; + } else { + return i; + } + } else if (ai == bi) { + continue; + } else { + return i; + } + } + return max_idx; // all tokens are equal + } + + // make sure all text tokens are within the vocab range + bool validate(const struct llama_context * ctx) const { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); + + for (size_t i = 0; i < tokens.size(); ++i) { + auto & t = tokens[i]; + if (t == LLAMA_TOKEN_NULL) { + try { + const auto & chunk = find_chunk(i); + size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); + i += n_pos - 1; // will be +1 by the for loop + } catch (const std::exception & e) { + return false; + } + } else if (t < 0 || t >= n_vocab) { + return false; + } + } + return true; + } + + // encode and decode the image chunk + int32_t process_chunk( + llama_context * ctx, + mtmd_context * mctx, + llama_pos n_past, + int32_t seq_id, + llama_pos & n_pos_out) { + auto & chunk = find_chunk(n_past); + const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE + ? "image" : "audio"; + SRV_INF("processing %s...\n", name); + int32_t n_batch = llama_n_batch(ctx); + int64_t t0 = ggml_time_ms(); + llama_pos new_n_past = n_past; + int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, + chunk.get(), + n_past, + seq_id, + n_batch, + true, // logits last + &new_n_past); + SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); + if (result != 0) { + LOG_ERR("mtmd_helper_eval failed with status %d", result); + n_pos_out = n_past; + return result; + } + n_pos_out = new_n_past; + return 0; + } +}; diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index c2a56aaa7f27b..5101144563a10 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -34,8 +34,9 @@ add_executable(${TARGET} ${TARGET_SRCS}) install(TARGETS ${TARGET} RUNTIME) target_include_directories(${TARGET} PRIVATE ../llava) +target_include_directories(${TARGET} PRIVATE ../server-core) target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) -target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE server-core common mtmd ${CMAKE_THREAD_LIBS_INIT}) if (LLAMA_SERVER_SSL) find_package(OpenSSL REQUIRED) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index d3f6271931f62..970b3e5739e4a 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1,16 +1,6 @@ -#include "chat.h" +#include "server-core.hpp" #include "utils.hpp" -#include "arg.h" -#include "common.h" -#include "json-schema-to-grammar.h" -#include "llama.h" -#include "log.h" -#include "sampling.h" -#include "speculative.h" -#include "mtmd.h" -#include "mtmd-helper.h" - // mime type for sending response #define MIMETYPE_JSON "application/json; charset=utf-8" @@ -18,3620 +8,6 @@ #include "index.html.gz.hpp" #include "loading.html.hpp" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using json = nlohmann::ordered_json; - -constexpr int HTTP_POLLING_SECONDS = 1; - -enum stop_type { - STOP_TYPE_NONE, - STOP_TYPE_EOS, - STOP_TYPE_WORD, - STOP_TYPE_LIMIT, -}; - -// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 -enum slot_state { - SLOT_STATE_IDLE, - SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future - SLOT_STATE_PROCESSING_PROMPT, - SLOT_STATE_DONE_PROMPT, - SLOT_STATE_GENERATING, -}; - -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded -}; - -enum server_task_type { - SERVER_TASK_TYPE_COMPLETION, - SERVER_TASK_TYPE_EMBEDDING, - SERVER_TASK_TYPE_RERANK, - SERVER_TASK_TYPE_INFILL, - SERVER_TASK_TYPE_CANCEL, - SERVER_TASK_TYPE_NEXT_RESPONSE, - SERVER_TASK_TYPE_METRICS, - SERVER_TASK_TYPE_SLOT_SAVE, - SERVER_TASK_TYPE_SLOT_RESTORE, - SERVER_TASK_TYPE_SLOT_ERASE, - SERVER_TASK_TYPE_SET_LORA, -}; - -enum oaicompat_type { - OAICOMPAT_TYPE_NONE, - OAICOMPAT_TYPE_CHAT, - OAICOMPAT_TYPE_COMPLETION, - OAICOMPAT_TYPE_EMBEDDING, -}; - -// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 -enum error_type { - ERROR_TYPE_INVALID_REQUEST, - ERROR_TYPE_AUTHENTICATION, - ERROR_TYPE_SERVER, - ERROR_TYPE_NOT_FOUND, - ERROR_TYPE_PERMISSION, - ERROR_TYPE_UNAVAILABLE, // custom error - ERROR_TYPE_NOT_SUPPORTED, // custom error -}; - -static bool server_task_type_need_embd(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_EMBEDDING: - case SERVER_TASK_TYPE_RERANK: - return true; - default: - return false; - } -} - -static bool server_task_type_need_logits(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_COMPLETION: - case SERVER_TASK_TYPE_INFILL: - return true; - default: - return false; - } -} - -struct slot_params { - bool stream = true; - bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt - bool return_tokens = false; - - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half - int32_t n_predict = -1; // new tokens to predict - int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters - - int64_t t_max_prompt_ms = -1; // TODO: implement - int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit - - std::vector lora; - - std::vector antiprompt; - std::vector response_fields; - bool timings_per_token = false; - bool post_sampling_probs = false; - bool ignore_eos = false; - - struct common_params_sampling sampling; - struct common_params_speculative speculative; - - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_syntax oaicompat_chat_syntax; - - json to_json() const { - std::vector samplers; - samplers.reserve(sampling.samplers.size()); - for (const auto & sampler : sampling.samplers) { - samplers.emplace_back(common_sampler_type_to_str(sampler)); - } - - json lora = json::array(); - for (size_t i = 0; i < this->lora.size(); ++i) { - lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); - } - - auto grammar_triggers = json::array(); - for (const auto & trigger : sampling.grammar_triggers) { - server_grammar_trigger ct(std::move(trigger)); - grammar_triggers.push_back(ct.to_json()); - } - - return json { - {"n_predict", n_predict}, // Server configured n_predict - {"seed", sampling.seed}, - {"temperature", sampling.temp}, - {"dynatemp_range", sampling.dynatemp_range}, - {"dynatemp_exponent", sampling.dynatemp_exponent}, - {"top_k", sampling.top_k}, - {"top_p", sampling.top_p}, - {"min_p", sampling.min_p}, - {"top_n_sigma", sampling.top_n_sigma}, - {"xtc_probability", sampling.xtc_probability}, - {"xtc_threshold", sampling.xtc_threshold}, - {"typical_p", sampling.typ_p}, - {"repeat_last_n", sampling.penalty_last_n}, - {"repeat_penalty", sampling.penalty_repeat}, - {"presence_penalty", sampling.penalty_present}, - {"frequency_penalty", sampling.penalty_freq}, - {"dry_multiplier", sampling.dry_multiplier}, - {"dry_base", sampling.dry_base}, - {"dry_allowed_length", sampling.dry_allowed_length}, - {"dry_penalty_last_n", sampling.dry_penalty_last_n}, - {"dry_sequence_breakers", sampling.dry_sequence_breakers}, - {"mirostat", sampling.mirostat}, - {"mirostat_tau", sampling.mirostat_tau}, - {"mirostat_eta", sampling.mirostat_eta}, - {"stop", antiprompt}, - {"max_tokens", n_predict}, // User configured n_predict - {"n_keep", n_keep}, - {"n_discard", n_discard}, - {"ignore_eos", sampling.ignore_eos}, - {"stream", stream}, - {"logit_bias", format_logit_bias(sampling.logit_bias)}, - {"n_probs", sampling.n_probs}, - {"min_keep", sampling.min_keep}, - {"grammar", sampling.grammar}, - {"grammar_lazy", sampling.grammar_lazy}, - {"grammar_triggers", grammar_triggers}, - {"preserved_tokens", sampling.preserved_tokens}, - {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, - {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, - {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, - {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, - {"samplers", samplers}, - {"speculative.n_max", speculative.n_max}, - {"speculative.n_min", speculative.n_min}, - {"speculative.p_min", speculative.p_min}, - {"timings_per_token", timings_per_token}, - {"post_sampling_probs", post_sampling_probs}, - {"lora", lora}, - }; - } -}; - -struct server_task { - int id = -1; // to be filled by server_queue - int index = -1; // used when there are multiple prompts (batch request) - - server_task_type type; - - // used by SERVER_TASK_TYPE_CANCEL - int id_target = -1; - - // used by SERVER_TASK_TYPE_INFERENCE - slot_params params; - server_tokens prompt_tokens; - int id_selected_slot = -1; - - // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE - struct slot_action { - int slot_id; - std::string filename; - std::string filepath; - }; - slot_action slot_action; - - // used by SERVER_TASK_TYPE_METRICS - bool metrics_reset_bucket = false; - - // used by SERVER_TASK_TYPE_SET_LORA - std::vector set_lora; - - server_task(server_task_type type) : type(type) {} - - static slot_params params_from_json_cmpl( - const llama_context * ctx, - const common_params & params_base, - const json & data) { - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - - slot_params params; - - // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) - slot_params defaults; - defaults.sampling = params_base.sampling; - defaults.speculative = params_base.speculative; - defaults.n_keep = params_base.n_keep; - - // enabling this will output extra debug information in the HTTP responses from the server - params.verbose = params_base.verbosity > 9; - params.timings_per_token = json_value(data, "timings_per_token", false); - - params.stream = json_value(data, "stream", false); - params.cache_prompt = json_value(data, "cache_prompt", true); - params.return_tokens = json_value(data, "return_tokens", false); - params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); - params.n_indent = json_value(data, "n_indent", defaults.n_indent); - params.n_keep = json_value(data, "n_keep", defaults.n_keep); - params.n_discard = json_value(data, "n_discard", defaults.n_discard); - //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement - params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); - params.response_fields = json_value(data, "response_fields", std::vector()); - - params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); - params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); - params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); - params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); - params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); - params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); - params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); - params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); - params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); - params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); - params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); - params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); - params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); - params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); - params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); - params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); - params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); - params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); - params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); - params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); - params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); - params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); - params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); - params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); - - params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); - params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); - params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); - - params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); - params.speculative.n_min = std::max(params.speculative.n_min, 0); - params.speculative.n_max = std::max(params.speculative.n_max, 0); - - // Use OpenAI API logprobs only if n_probs wasn't provided - if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ - params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); - } - - if (data.contains("lora")) { - if (data.at("lora").is_array()) { - params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); - } else { - throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); - } - } else { - params.lora = params_base.lora_adapters; - } - - // TODO: add more sanity checks for the input parameters - - if (params.sampling.penalty_last_n < -1) { - throw std::runtime_error("Error: repeat_last_n must be >= -1"); - } - - if (params.sampling.dry_penalty_last_n < -1) { - throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); - } - - if (params.sampling.penalty_last_n == -1) { - // note: should be the slot's context and not the full context, but it's ok - params.sampling.penalty_last_n = llama_n_ctx(ctx); - } - - if (params.sampling.dry_penalty_last_n == -1) { - params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); - } - - if (params.sampling.dry_base < 1.0f) { - params.sampling.dry_base = defaults.sampling.dry_base; - } - - // sequence breakers for DRY - { - // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format - // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 - - if (data.contains("dry_sequence_breakers")) { - params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); - if (params.sampling.dry_sequence_breakers.empty()) { - throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); - } - } - } - - // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.contains("grammar")) { - try { - auto schema = json_value(data, "json_schema", json::object()); - SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); - params.sampling.grammar = json_schema_to_grammar(schema); - SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); - } catch (const std::exception & e) { - throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); - } - } else { - params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); - SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); - params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); - SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); - } - - { - auto it = data.find("chat_format"); - if (it != data.end()) { - params.oaicompat_chat_syntax.format = static_cast(it->get()); - SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format)); - } else { - params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; - } - params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format; - params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (params_base.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); - params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); - params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); - } - - { - const auto preserved_tokens = data.find("preserved_tokens"); - if (preserved_tokens != data.end()) { - for (const auto & t : *preserved_tokens) { - auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - SRV_DBG("Preserved token: %d\n", ids[0]); - params.sampling.preserved_tokens.insert(ids[0]); - } else { - // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. - SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); - } - } - } - const auto grammar_triggers = data.find("grammar_triggers"); - if (grammar_triggers != data.end()) { - for (const auto & t : *grammar_triggers) { - server_grammar_trigger ct(t); - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { - const auto & word = ct.value.value; - auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - auto token = ids[0]; - if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) { - throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); - } - SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); - common_grammar_trigger trigger; - trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; - trigger.value = word; - trigger.token = token; - params.sampling.grammar_triggers.push_back(std::move(trigger)); - } else { - SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); - params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); - } - } else { - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { - SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); - } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { - SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); - } else { - throw std::runtime_error("Unknown grammar trigger type"); - } - params.sampling.grammar_triggers.emplace_back(std::move(ct.value)); - } - } - } - if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { - throw std::runtime_error("Error: no triggers set for lazy grammar!"); - } - } - - { - params.sampling.logit_bias.clear(); - params.ignore_eos = json_value(data, "ignore_eos", false); - - const auto & logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & el : *logit_bias) { - // TODO: we may want to throw errors here, in case "el" is incorrect - if (el.is_array() && el.size() == 2) { - float bias; - if (el[1].is_number()) { - bias = el[1].get(); - } else if (el[1].is_boolean() && !el[1].get()) { - bias = -INFINITY; - } else { - continue; - } - - if (el[0].is_number_integer()) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } else if (el[0].is_string()) { - auto toks = common_tokenize(vocab, el[0].get(), false); - for (auto tok : toks) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } - } - } - } - } - - { - params.antiprompt.clear(); - - const auto & stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto & word : *stop) { - if (!word.empty()) { - params.antiprompt.push_back(word); - } - } - } - } - - { - const auto samplers = data.find("samplers"); - if (samplers != data.end()) { - if (samplers->is_array()) { - params.sampling.samplers = common_sampler_types_from_names(*samplers, false); - } else if (samplers->is_string()){ - params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); - } - } else { - params.sampling.samplers = defaults.sampling.samplers; - } - } - - std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; - params.oaicompat_model = json_value(data, "model", model_name); - - return params; - } - - // utility function - static std::unordered_set get_list_id(const std::vector & tasks) { - std::unordered_set ids(tasks.size()); - for (size_t i = 0; i < tasks.size(); i++) { - ids.insert(tasks[i].id); - } - return ids; - } -}; - -struct result_timings { - int32_t prompt_n = -1; - double prompt_ms; - double prompt_per_token_ms; - double prompt_per_second; - - int32_t predicted_n = -1; - double predicted_ms; - double predicted_per_token_ms; - double predicted_per_second; - - // Optional speculative metrics - only included when > 0 - int32_t draft_n = 0; - int32_t draft_n_accepted = 0; - - json to_json() const { - json base = { - {"prompt_n", prompt_n}, - {"prompt_ms", prompt_ms}, - {"prompt_per_token_ms", prompt_per_token_ms}, - {"prompt_per_second", prompt_per_second}, - - {"predicted_n", predicted_n}, - {"predicted_ms", predicted_ms}, - {"predicted_per_token_ms", predicted_per_token_ms}, - {"predicted_per_second", predicted_per_second}, - }; - - if (draft_n > 0) { - base["draft_n"] = draft_n; - base["draft_n_accepted"] = draft_n_accepted; - } - - return base; - } -}; - -struct server_task_result { - int id = -1; - int id_slot = -1; - virtual bool is_error() { - // only used by server_task_result_error - return false; - } - virtual bool is_stop() { - // only used by server_task_result_cmpl_* - return false; - } - virtual int get_index() { - return -1; - } - virtual json to_json() = 0; - virtual ~server_task_result() = default; -}; - -// using shared_ptr for polymorphism of server_task_result -using server_task_result_ptr = std::unique_ptr; - -inline std::string stop_type_to_str(stop_type type) { - switch (type) { - case STOP_TYPE_EOS: return "eos"; - case STOP_TYPE_WORD: return "word"; - case STOP_TYPE_LIMIT: return "limit"; - default: return "none"; - } -} - -struct completion_token_output { - llama_token tok; - float prob; - std::string text_to_send; - struct prob_info { - llama_token tok; - std::string txt; - float prob; - }; - std::vector probs; - - json to_json(bool post_sampling_probs) const { - json probs_for_token = json::array(); - for (const auto & p : probs) { - std::string txt(p.txt); - txt.resize(validate_utf8(txt)); - probs_for_token.push_back(json { - {"id", p.tok}, - {"token", txt}, - {"bytes", str_to_bytes(p.txt)}, - { - post_sampling_probs ? "prob" : "logprob", - post_sampling_probs ? p.prob : logarithm(p.prob) - }, - }); - } - return probs_for_token; - } - - static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { - json out = json::array(); - for (const auto & p : probs) { - std::string txt(p.text_to_send); - txt.resize(validate_utf8(txt)); - out.push_back(json { - {"id", p.tok}, - {"token", txt}, - {"bytes", str_to_bytes(p.text_to_send)}, - { - post_sampling_probs ? "prob" : "logprob", - post_sampling_probs ? p.prob : logarithm(p.prob) - }, - { - post_sampling_probs ? "top_probs" : "top_logprobs", - p.to_json(post_sampling_probs) - }, - }); - } - return out; - } - - static float logarithm(float x) { - // nlohmann::json converts -inf to null, so we need to prevent that - return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); - } - - static std::vector str_to_bytes(const std::string & str) { - std::vector bytes; - for (unsigned char c : str) { - bytes.push_back(c); - } - return bytes; - } -}; - -struct server_task_result_cmpl_final : server_task_result { - int index = 0; - - std::string content; - llama_tokens tokens; - - bool stream; - result_timings timings; - std::string prompt; - - bool truncated; - int32_t n_decoded; - int32_t n_prompt_tokens; - int32_t n_tokens_cached; - bool has_new_line; - std::string stopping_word; - stop_type stop = STOP_TYPE_NONE; - - bool post_sampling_probs; - std::vector probs_output; - std::vector response_fields; - - slot_params generation_params; - - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_msg oaicompat_msg; - std::vector oaicompat_msg_diffs; - - virtual int get_index() override { - return index; - } - - virtual bool is_stop() override { - return true; // in stream mode, final responses are considered stop - } - - virtual json to_json() override { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: - return to_json_non_oaicompat(); - case OAICOMPAT_TYPE_COMPLETION: - return to_json_oaicompat(); - case OAICOMPAT_TYPE_CHAT: - return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); - default: - GGML_ASSERT(false && "Invalid oaicompat_type"); - } - } - - json to_json_non_oaicompat() { - json res = json { - {"index", index}, - {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk - {"tokens", stream ? llama_tokens {} : tokens}, - {"id_slot", id_slot}, - {"stop", true}, - {"model", oaicompat_model}, - {"tokens_predicted", n_decoded}, - {"tokens_evaluated", n_prompt_tokens}, - {"generation_settings", generation_params.to_json()}, - {"prompt", prompt}, - {"has_new_line", has_new_line}, - {"truncated", truncated}, - {"stop_type", stop_type_to_str(stop)}, - {"stopping_word", stopping_word}, - {"tokens_cached", n_tokens_cached}, - {"timings", timings.to_json()}, - }; - if (!stream && !probs_output.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); - } - return response_fields.empty() ? res : json_get_nested_values(response_fields, res); - } - - json to_json_oaicompat() { - std::time_t t = std::time(0); - json logprobs = json(nullptr); // OAI default to null - if (!stream && probs_output.size() > 0) { - logprobs = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - json finish_reason = "length"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = "stop"; - } - json res = json { - {"choices", json::array({ - json{ - {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk - {"index", index}, - {"logprobs", logprobs}, - {"finish_reason", finish_reason}, - } - })}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "text_completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - - return res; - } - - json to_json_oaicompat_chat() { - std::string finish_reason = "length"; - common_chat_msg msg; - if (!oaicompat_msg.empty()) { - msg = oaicompat_msg; - } else { - msg.role = "assistant"; - msg.content = content; - } - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; - } - - json choice { - {"finish_reason", finish_reason}, - {"index", 0}, - {"message", msg.to_json_oaicompat()}, - }; - - if (!stream && probs_output.size() > 0) { - choice["logprobs"] = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - - std::time_t t = std::time(0); - - json res = json { - {"choices", json::array({choice})}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - - return res; - } - - json to_json_oaicompat_chat_stream() { - std::time_t t = std::time(0); - std::string finish_reason = "length"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls"; - } - - json deltas = json::array(); - for (const auto & diff : oaicompat_msg_diffs) { - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - } - - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}, - }}, - }); - - if (timings.prompt_n >= 0) { - deltas.back().push_back({"timings", timings.to_json()}); - } - - // extra fields for debugging purposes - if (verbose && !deltas.empty()) { - deltas.front()["__verbose"] = to_json_non_oaicompat(); - } - - return deltas; - } -}; - -struct server_task_result_cmpl_partial : server_task_result { - int index = 0; - - std::string content; - llama_tokens tokens; - - int32_t n_decoded; - int32_t n_prompt_tokens; - - bool post_sampling_probs; - completion_token_output prob_output; - result_timings timings; - - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - std::vector oaicompat_msg_diffs; - - virtual int get_index() override { - return index; - } - - virtual bool is_stop() override { - return false; // in stream mode, partial responses are not considered stop - } - - virtual json to_json() override { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: - return to_json_non_oaicompat(); - case OAICOMPAT_TYPE_COMPLETION: - return to_json_oaicompat(); - case OAICOMPAT_TYPE_CHAT: - return to_json_oaicompat_chat(); - default: - GGML_ASSERT(false && "Invalid oaicompat_type"); - } - } - - json to_json_non_oaicompat() { - // non-OAI-compat JSON - json res = json { - {"index", index}, - {"content", content}, - {"tokens", tokens}, - {"stop", false}, - {"id_slot", id_slot}, - {"tokens_predicted", n_decoded}, - {"tokens_evaluated", n_prompt_tokens}, - }; - // populate the timings object when needed (usually for the last response or with timings_per_token enabled) - if (timings.prompt_n > 0) { - res.push_back({"timings", timings.to_json()}); - } - if (!prob_output.probs.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); - } - return res; - } - - json to_json_oaicompat() { - std::time_t t = std::time(0); - json logprobs = json(nullptr); // OAI default to null - if (prob_output.probs.size() > 0) { - logprobs = json{ - {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, - }; - } - json res = json { - {"choices", json::array({ - json{ - {"text", content}, - {"index", index}, - {"logprobs", logprobs}, - {"finish_reason", nullptr}, - } - })}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "text_completion"}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - - return res; - } - - json to_json_oaicompat_chat() { - bool first = n_decoded == 1; - std::time_t t = std::time(0); - json choices; - - std::vector deltas; - auto add_delta = [&](const json & delta) { - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", delta}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - }; - // We have to send an initial update to conform to openai behavior - if (first) { - add_delta({ - {"role", "assistant"}, - {"content", nullptr}, - }); - } - - for (const auto & diff : oaicompat_msg_diffs) { - add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); - } - - if (!deltas.empty()) { - GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1); - - if (prob_output.probs.size() > 0) { - deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json { - {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, - }; - } - - if (timings.prompt_n >= 0) { - deltas[deltas.size() - 1].push_back({"timings", timings.to_json()}); - } - } - - return deltas; - } -}; - -struct server_task_result_embd : server_task_result { - int index = 0; - std::vector> embedding; - - int32_t n_tokens; - - // OAI-compat fields - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - - virtual int get_index() override { - return index; - } - - virtual json to_json() override { - return oaicompat == OAICOMPAT_TYPE_EMBEDDING - ? to_json_oaicompat() - : to_json_non_oaicompat(); - } - - json to_json_non_oaicompat() { - return json { - {"index", index}, - {"embedding", embedding}, - }; - } - - json to_json_oaicompat() { - return json { - {"index", index}, - {"embedding", embedding[0]}, - {"tokens_evaluated", n_tokens}, - }; - } -}; - -struct server_task_result_rerank : server_task_result { - int index = 0; - float score = -1e6; - - int32_t n_tokens; - - virtual int get_index() override { - return index; - } - - virtual json to_json() override { - return json { - {"index", index}, - {"score", score}, - {"tokens_evaluated", n_tokens}, - }; - } -}; - -// this function maybe used outside of server_task_result_error -static json format_error_response(const std::string & message, const enum error_type type) { - std::string type_str; - int code = 500; - switch (type) { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - code = 400; - break; - case ERROR_TYPE_AUTHENTICATION: - type_str = "authentication_error"; - code = 401; - break; - case ERROR_TYPE_NOT_FOUND: - type_str = "not_found_error"; - code = 404; - break; - case ERROR_TYPE_SERVER: - type_str = "server_error"; - code = 500; - break; - case ERROR_TYPE_PERMISSION: - type_str = "permission_error"; - code = 403; - break; - case ERROR_TYPE_NOT_SUPPORTED: - type_str = "not_supported_error"; - code = 501; - break; - case ERROR_TYPE_UNAVAILABLE: - type_str = "unavailable_error"; - code = 503; - break; - } - return json { - {"code", code}, - {"message", message}, - {"type", type_str}, - }; -} - -struct server_task_result_error : server_task_result { - int index = 0; - error_type err_type = ERROR_TYPE_SERVER; - std::string err_msg; - - virtual bool is_error() override { - return true; - } - - virtual json to_json() override { - return format_error_response(err_msg, err_type); - } -}; - -struct server_task_result_metrics : server_task_result { - int n_idle_slots; - int n_processing_slots; - int n_tasks_deferred; - int64_t t_start; - - // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; - - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - - uint64_t n_decode_total = 0; - uint64_t n_busy_slots_total = 0; - - // while we can also use std::vector this requires copying the slot object which can be quite messy - // therefore, we use json to temporarily store the slot.to_json() result - json slots_data = json::array(); - - virtual json to_json() override { - return json { - { "idle", n_idle_slots }, - { "processing", n_processing_slots }, - { "deferred", n_tasks_deferred }, - { "t_start", t_start }, - - { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, - { "t_tokens_generation_total", t_tokens_generation_total }, - { "n_tokens_predicted_total", n_tokens_predicted_total }, - { "t_prompt_processing_total", t_prompt_processing_total }, - - { "n_prompt_tokens_processed", n_prompt_tokens_processed }, - { "t_prompt_processing", t_prompt_processing }, - { "n_tokens_predicted", n_tokens_predicted }, - { "t_tokens_generation", t_tokens_generation }, - - { "n_decode_total", n_decode_total }, - { "n_busy_slots_total", n_busy_slots_total }, - - { "slots", slots_data }, - }; - } -}; - -struct server_task_result_slot_save_load : server_task_result { - std::string filename; - bool is_save; // true = save, false = load - - size_t n_tokens; - size_t n_bytes; - double t_ms; - - virtual json to_json() override { - if (is_save) { - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_saved", n_tokens }, - { "n_written", n_bytes }, - { "timings", { - { "save_ms", t_ms } - }}, - }; - } else { - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_restored", n_tokens }, - { "n_read", n_bytes }, - { "timings", { - { "restore_ms", t_ms } - }}, - }; - } - } -}; - -struct server_task_result_slot_erase : server_task_result { - size_t n_erased; - - virtual json to_json() override { - return json { - { "id_slot", id_slot }, - { "n_erased", n_erased }, - }; - } -}; - -struct server_task_result_apply_lora : server_task_result { - virtual json to_json() override { - return json {{ "success", true }}; - } -}; - -struct server_slot { - int id; - int id_task = -1; - - // only used for completion/embedding/infill/rerank - server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; - - llama_batch batch_spec = {}; - - llama_context * ctx = nullptr; - llama_context * ctx_dft = nullptr; - - // multimodal - mtmd_context * mctx = nullptr; - - common_speculative * spec = nullptr; - - std::vector lora; - - // the index relative to completion multi-task request - size_t index = 0; - - struct slot_params params; - - slot_state state = SLOT_STATE_IDLE; - - // used to determine the slot that has been used the longest - int64_t t_last_used = -1; - - // generation props - int32_t n_ctx = 0; // context size per slot - int32_t n_past = 0; - int32_t n_decoded = 0; - int32_t n_remaining = -1; - int32_t i_batch = -1; - int32_t n_predict = -1; // TODO: disambiguate from params.n_predict - - // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated - int32_t n_prompt_tokens = 0; - int32_t n_prompt_tokens_processed = 0; - - // input prompt tokens - server_tokens prompt_tokens; - - size_t last_nl_pos = 0; - - std::string generated_text; - llama_tokens generated_tokens; - common_chat_msg chat_msg; - - server_tokens cache_tokens; - - std::vector generated_token_probs; - - bool has_next_token = true; - bool has_new_line = false; - bool truncated = false; - stop_type stop; - - std::string stopping_word; - - // sampling - json json_schema; - - struct common_sampler * smpl = nullptr; - - llama_token sampled; - - common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - std::vector generated_tool_call_ids; - - // stats - size_t n_sent_text = 0; // number of sent text character - - int64_t t_start_process_prompt; - int64_t t_start_generation; - - double t_prompt_processing; // ms - double t_token_generation; // ms - - std::function callback_on_release; - - // Speculative decoding stats - int32_t n_draft_total = 0; // Total draft tokens generated - int32_t n_draft_accepted = 0; // Draft tokens actually accepted - - void reset() { - SLT_DBG(*this, "%s", "\n"); - - n_prompt_tokens = 0; - last_nl_pos = 0; - generated_text = ""; - has_new_line = false; - truncated = false; - stop = STOP_TYPE_NONE; - stopping_word = ""; - n_past = 0; - n_sent_text = 0; - task_type = SERVER_TASK_TYPE_COMPLETION; - chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - - generated_tokens.clear(); - generated_token_probs.clear(); - chat_msg = {}; - json_schema = json(); - generated_tool_call_ids.clear(); - - // clear speculative decoding stats - n_draft_total = 0; - n_draft_accepted = 0; - } - - bool need_embd() const { - return server_task_type_need_embd(task_type); - } - - bool need_logits() const { - return server_task_type_need_logits(task_type); - } - - // if the context does not have a memory module then all embeddings have to be computed within a single ubatch - // also we cannot split if the pooling would require any past tokens - bool can_split() const { - return - !need_embd() || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); - } - - bool can_batch_with(server_slot & other_slot) const { - return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora); - } - - bool has_budget(const common_params & global_params) { - if (params.n_predict == -1 && global_params.n_predict == -1) { - return true; // limitless - } - - n_remaining = -1; - - if (params.n_predict != -1) { - n_remaining = params.n_predict - n_decoded; - } else if (global_params.n_predict != -1) { - n_remaining = global_params.n_predict - n_decoded; - } - - return n_remaining > 0; // no budget - } - - bool is_processing() const { - return state != SLOT_STATE_IDLE; - } - - bool can_speculate() const { - return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; - } - - void add_token(const completion_token_output & token) { - if (!is_processing()) { - SLT_WRN(*this, "%s", "slot is not processing\n"); - return; - } - generated_token_probs.push_back(token); - } - - void release() { - if (is_processing()) { - SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); - - t_last_used = ggml_time_us(); - t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; - state = SLOT_STATE_IDLE; - callback_on_release(id); - } - } - - result_timings get_timings() const { - result_timings timings; - timings.prompt_n = n_prompt_tokens_processed; - timings.prompt_ms = t_prompt_processing; - timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; - timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - timings.predicted_n = n_decoded; - timings.predicted_ms = t_token_generation; - timings.predicted_per_token_ms = t_token_generation / n_decoded; - timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; - - // Add speculative metrics - if (n_draft_total > 0) { - timings.draft_n = n_draft_total; - timings.draft_n_accepted = n_draft_accepted; - } - - return timings; - } - - const common_chat_msg & update_chat_msg(std::vector & diffs) { - auto previous_msg = chat_msg; - SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); - auto new_msg = common_chat_parse( - generated_text, - /* is_partial= */ stop != STOP_TYPE_EOS, - params.oaicompat_chat_syntax); - if (!new_msg.empty()) { - new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id); - chat_msg = new_msg; - diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg); - } - return chat_msg; - } - - size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { - size_t stop_pos = std::string::npos; - - for (const std::string & word : params.antiprompt) { - size_t pos; - - if (is_full_stop) { - const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; - - pos = text.find(word, from_pos); - } else { - // otherwise, partial stop - pos = string_find_partial_stop(text, word); - } - - if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { - if (is_full_stop) { - stop = STOP_TYPE_WORD; - stopping_word = word; - has_next_token = false; - } - stop_pos = pos; - } - } - - return stop_pos; - } - - void print_timings() const { - const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; - const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - const double t_gen = t_token_generation / n_decoded; - const double n_gen_second = 1e3 / t_token_generation * n_decoded; - - SLT_INF(*this, - "\n" - "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" - " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" - " total time = %10.2f ms / %5d tokens\n", - t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, - t_token_generation, n_decoded, t_gen, n_gen_second, - t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); - - if (n_draft_total > 0) { - const float draft_ratio = (float) n_draft_accepted / n_draft_total; - SLT_INF(*this, - "\n" - "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", - draft_ratio, n_draft_accepted, n_draft_total - ); - } - } - - json to_json() const { - return json { - {"id", id}, - {"id_task", id_task}, - {"n_ctx", n_ctx}, - {"speculative", can_speculate()}, - {"is_processing", is_processing()}, - {"params", params.to_json()}, - {"prompt", prompt_tokens.detokenize(ctx, true)}, - {"next_token", - { - {"has_next_token", has_next_token}, - {"has_new_line", has_new_line}, - {"n_remain", n_remaining}, - {"n_decoded", n_decoded}, - {"stopping_word", stopping_word}, - } - }, - }; - } -}; - -struct server_metrics { - int64_t t_start = 0; - - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; - - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - - uint64_t n_decode_total = 0; - uint64_t n_busy_slots_total = 0; - - void init() { - t_start = ggml_time_us(); - } - - void on_prompt_eval(const server_slot & slot) { - n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; - n_prompt_tokens_processed += slot.n_prompt_tokens_processed; - t_prompt_processing += slot.t_prompt_processing; - t_prompt_processing_total += slot.t_prompt_processing; - } - - void on_prediction(const server_slot & slot) { - n_tokens_predicted_total += slot.n_decoded; - n_tokens_predicted += slot.n_decoded; - t_tokens_generation += slot.t_token_generation; - t_tokens_generation_total += slot.t_token_generation; - } - - void on_decoded(const std::vector & slots) { - n_decode_total++; - for (const auto & slot : slots) { - if (slot.is_processing()) { - n_busy_slots_total++; - } - } - } - - void reset_bucket() { - n_prompt_tokens_processed = 0; - t_prompt_processing = 0; - n_tokens_predicted = 0; - t_tokens_generation = 0; - } -}; - -struct server_queue { - int id = 0; - bool running; - - // queues - std::deque queue_tasks; - std::deque queue_tasks_deferred; - - std::mutex mutex_tasks; - std::condition_variable condition_tasks; - - // callback functions - std::function callback_new_task; - std::function callback_update_slots; - - // Add a new task to the end of the queue - int post(server_task && task, bool front = false) { - std::unique_lock lock(mutex_tasks); - GGML_ASSERT(task.id != -1); - // if this is cancel task make sure to clean up pending tasks - if (task.type == SERVER_TASK_TYPE_CANCEL) { - cleanup_pending_task(task.id_target); - } - const int task_id = task.id; - QUE_DBG("new task, id = %d, front = %d\n", task_id, front); - if (front) { - queue_tasks.push_front(std::move(task)); - } else { - queue_tasks.push_back(std::move(task)); - } - condition_tasks.notify_one(); - return task_id; - } - - // multi-task version of post() - int post(std::vector && tasks, bool front = false) { - std::unique_lock lock(mutex_tasks); - for (auto & task : tasks) { - if (task.id == -1) { - task.id = id++; - } - // if this is cancel task make sure to clean up pending tasks - if (task.type == SERVER_TASK_TYPE_CANCEL) { - cleanup_pending_task(task.id_target); - } - QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); - if (front) { - queue_tasks.push_front(std::move(task)); - } else { - queue_tasks.push_back(std::move(task)); - } - } - condition_tasks.notify_one(); - return 0; - } - - // Add a new task, but defer until one slot is available - void defer(server_task && task) { - std::unique_lock lock(mutex_tasks); - QUE_DBG("defer task, id = %d\n", task.id); - queue_tasks_deferred.push_back(std::move(task)); - condition_tasks.notify_one(); - } - - // Get the next id for creating a new task - int get_new_id() { - std::unique_lock lock(mutex_tasks); - int new_id = id++; - return new_id; - } - - // Register function to process a new task - void on_new_task(std::function callback) { - callback_new_task = std::move(callback); - } - - // Register the function to be called when all slots data is ready to be processed - void on_update_slots(std::function callback) { - callback_update_slots = std::move(callback); - } - - // Call when the state of one slot is changed, it will move one task from deferred to main queue - void pop_deferred_task() { - std::unique_lock lock(mutex_tasks); - if (!queue_tasks_deferred.empty()) { - queue_tasks.emplace_back(std::move(queue_tasks_deferred.front())); - queue_tasks_deferred.pop_front(); - } - condition_tasks.notify_one(); - } - - // end the start_loop routine - void terminate() { - std::unique_lock lock(mutex_tasks); - running = false; - condition_tasks.notify_all(); - } - - /** - * Main loop consists of these steps: - * - Wait until a new task arrives - * - Process the task (i.e. maybe copy data into slot) - * - Check if multitask is finished - * - Update all slots - */ - void start_loop() { - running = true; - - while (true) { - QUE_DBG("%s", "processing new tasks\n"); - - while (true) { - std::unique_lock lock(mutex_tasks); - if (!running) { - QUE_DBG("%s", "terminate\n"); - return; - } - if (queue_tasks.empty()) { - lock.unlock(); - break; - } - server_task task = std::move(queue_tasks.front()); - queue_tasks.pop_front(); - lock.unlock(); - - QUE_DBG("processing task, id = %d\n", task.id); - callback_new_task(std::move(task)); - } - - // all tasks in the current loop is processed, slots data is now ready - QUE_DBG("%s", "update slots\n"); - - callback_update_slots(); - - QUE_DBG("%s", "waiting for new tasks\n"); - { - std::unique_lock lock(mutex_tasks); - if (!running) { - QUE_DBG("%s", "terminate\n"); - return; - } - if (queue_tasks.empty()) { - condition_tasks.wait(lock, [&]{ - return (!queue_tasks.empty() || !running); - }); - } - } - } - } - -private: - void cleanup_pending_task(int id_target) { - // no need lock because this is called exclusively by post() - auto rm_func = [id_target](const server_task & task) { - return task.id_target == id_target; - }; - queue_tasks.erase( - std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), - queue_tasks.end()); - queue_tasks_deferred.erase( - std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), - queue_tasks_deferred.end()); - } -}; - -struct server_response { - bool running = true; - - // for keeping track of all tasks waiting for the result - std::unordered_set waiting_task_ids; - - // the main result queue (using ptr for polymorphism) - std::vector queue_results; - - std::mutex mutex_results; - std::condition_variable condition_results; - - // add the id_task to the list of tasks waiting for response - void add_waiting_task_id(int id_task) { - SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); - - std::unique_lock lock(mutex_results); - waiting_task_ids.insert(id_task); - } - - void add_waiting_tasks(const std::vector & tasks) { - std::unique_lock lock(mutex_results); - - for (const auto & task : tasks) { - SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); - waiting_task_ids.insert(task.id); - } - } - - // when the request is finished, we can remove task associated with it - void remove_waiting_task_id(int id_task) { - SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); - - std::unique_lock lock(mutex_results); - waiting_task_ids.erase(id_task); - // make sure to clean up all pending results - queue_results.erase( - std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { - return res->id == id_task; - }), - queue_results.end()); - } - - void remove_waiting_task_ids(const std::unordered_set & id_tasks) { - std::unique_lock lock(mutex_results); - - for (const auto & id_task : id_tasks) { - SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); - waiting_task_ids.erase(id_task); - } - } - - // This function blocks the thread until there is a response for one of the id_tasks - server_task_result_ptr recv(const std::unordered_set & id_tasks) { - while (true) { - std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&]{ - if (!running) { - SRV_DBG("%s : queue result stop\n", __func__); - std::terminate(); // we cannot return here since the caller is HTTP code - } - return !queue_results.empty(); - }); - - for (size_t i = 0; i < queue_results.size(); i++) { - if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { - server_task_result_ptr res = std::move(queue_results[i]); - queue_results.erase(queue_results.begin() + i); - return res; - } - } - } - - // should never reach here - } - - // same as recv(), but have timeout in seconds - // if timeout is reached, nullptr is returned - server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { - while (true) { - std::unique_lock lock(mutex_results); - - for (int i = 0; i < (int) queue_results.size(); i++) { - if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { - server_task_result_ptr res = std::move(queue_results[i]); - queue_results.erase(queue_results.begin() + i); - return res; - } - } - - std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); - if (!running) { - SRV_DBG("%s : queue result stop\n", __func__); - std::terminate(); // we cannot return here since the caller is HTTP code - } - if (cr_res == std::cv_status::timeout) { - return nullptr; - } - } - - // should never reach here - } - - // single-task version of recv() - server_task_result_ptr recv(int id_task) { - std::unordered_set id_tasks = {id_task}; - return recv(id_tasks); - } - - // Send a new result to a waiting id_task - void send(server_task_result_ptr && result) { - SRV_DBG("sending result for task id = %d\n", result->id); - - std::unique_lock lock(mutex_results); - for (const auto & id_task : waiting_task_ids) { - if (result->id == id_task) { - SRV_DBG("task id = %d pushed to result queue\n", result->id); - - queue_results.emplace_back(std::move(result)); - condition_results.notify_all(); - return; - } - } - } - - // terminate the waiting loop - void terminate() { - running = false; - condition_results.notify_all(); - } -}; - -struct server_context { - common_params params_base; - - // note: keep these alive - they determine the lifetime of the model, context, etc. - common_init_result llama_init; - common_init_result llama_init_dft; - - llama_model * model = nullptr; - llama_context * ctx = nullptr; - - // multimodal - mtmd_context * mctx = nullptr; - - const llama_vocab * vocab = nullptr; - - llama_model * model_dft = nullptr; - - llama_context_params cparams_dft; - - llama_batch batch {}; - - bool clean_kv_cache = true; - bool add_bos_token = true; - bool has_eos_token = false; - - int32_t n_ctx; // total context for all clients / slots - - // slots / clients - std::vector slots; - json default_generation_settings_for_props; - - server_queue queue_tasks; - server_response queue_results; - - server_metrics metrics; - - // Necessary similarity of prompt for slot selection - float slot_prompt_similarity = 0.0f; - - common_chat_templates_ptr chat_templates; - oaicompat_parser_options oai_parser_opt; - - ~server_context() { - mtmd_free(mctx); - - // Clear any sampling context - for (server_slot & slot : slots) { - common_sampler_free(slot.smpl); - slot.smpl = nullptr; - - llama_free(slot.ctx_dft); - slot.ctx_dft = nullptr; - - common_speculative_free(slot.spec); - slot.spec = nullptr; - - llama_batch_free(slot.batch_spec); - } - - llama_batch_free(batch); - } - - bool load_model(const common_params & params) { - SRV_INF("loading model '%s'\n", params.model.path.c_str()); - - params_base = params; - - llama_init = common_init_from_params(params_base); - - model = llama_init.model.get(); - ctx = llama_init.context.get(); - - if (model == nullptr) { - SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); - return false; - } - - vocab = llama_model_get_vocab(model); - - n_ctx = llama_n_ctx(ctx); - - add_bos_token = llama_vocab_get_add_bos(vocab); - has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; - - if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) { - SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); - - auto params_dft = params_base; - - params_dft.devices = params_base.speculative.devices; - params_dft.model = params_base.speculative.model; - params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; - params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; - params_dft.n_parallel = 1; - params_dft.cache_type_k = params_base.speculative.cache_type_k; - params_dft.cache_type_v = params_base.speculative.cache_type_v; - - llama_init_dft = common_init_from_params(params_dft); - - model_dft = llama_init_dft.model.get(); - - if (model_dft == nullptr) { - SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); - return false; - } - - if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) { - SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); - - return false; - } - - const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); - - cparams_dft = common_context_params_to_llama(params_dft); - cparams_dft.n_batch = n_ctx_dft; - - // the context is not needed - we will create one for each slot - llama_init_dft.context.reset(); - } - - chat_templates = common_chat_templates_init(model, params_base.chat_template); - try { - common_chat_format_example(chat_templates.get(), params.use_jinja); - } catch (const std::exception & e) { - SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); - SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); - chat_templates = common_chat_templates_init(model, "chatml"); - } - - std::string & mmproj_path = params_base.mmproj.path; - if (!mmproj_path.empty()) { - mtmd_context_params mparams = mtmd_context_params_default(); - mparams.use_gpu = params_base.mmproj_use_gpu; - mparams.print_timings = false; - mparams.n_threads = params_base.cpuparams.n_threads; - mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; - mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); - if (mctx == nullptr) { - SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); - return false; - } - SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); - - if (params_base.ctx_shift) { - params_base.ctx_shift = false; - SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); - } - - if (params_base.n_cache_reuse) { - params_base.n_cache_reuse = 0; - SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); - } - - if (!params_base.speculative.model.path.empty()) { - SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); - return false; - } - } - - if (!llama_memory_can_shift(llama_get_memory(ctx))) { - if (params_base.ctx_shift) { - params_base.ctx_shift = false; - SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); - } - - if (params_base.n_cache_reuse) { - params_base.n_cache_reuse = 0; - SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); - } - } - - return true; - } - - void init() { - const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; - - SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); - - for (int i = 0; i < params_base.n_parallel; i++) { - server_slot slot; - - slot.id = i; - slot.ctx = ctx; - slot.n_ctx = n_ctx_slot; - slot.n_predict = params_base.n_predict; - slot.mctx = mctx; - slot.cache_tokens.has_mtmd = mctx != nullptr; - - if (model_dft) { - slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); - - slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); - if (slot.ctx_dft == nullptr) { - SRV_ERR("%s", "failed to create draft context\n"); - return; - } - - slot.spec = common_speculative_init(slot.ctx_dft); - if (slot.spec == nullptr) { - SRV_ERR("%s", "failed to create speculator\n"); - return; - } - } - - SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); - - slot.params.sampling = params_base.sampling; - slot.params.n_keep = params_base.n_keep; - - slot.callback_on_release = [this](int) { - queue_tasks.pop_deferred_task(); - }; - - slot.reset(); - - slots.push_back(std::move(slot)); - } - - default_generation_settings_for_props = slots[0].to_json(); - - // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens - // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) - { - const int32_t n_batch = llama_n_batch(ctx); - batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); - } - - metrics.init(); - - oai_parser_opt = { - /* use_jinja */ params_base.use_jinja, - /* prefill_assistant */ params_base.prefill_assistant, - /* reasoning_format */ params_base.reasoning_format, - /* chat_template_kwargs */ params_base.default_template_kwargs, - /* common_chat_templates */ chat_templates.get(), - /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, - /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, - /* enable_thinking */ params_base.reasoning_budget != 0, - }; - } - - server_slot * get_slot_by_id(int id) { - for (server_slot & slot : slots) { - if (slot.id == id) { - return &slot; - } - } - - return nullptr; - } - - server_slot * get_available_slot(const server_task & task) { - server_slot * ret = nullptr; - - // find the slot that has at least n% prompt similarity - if (ret == nullptr && slot_prompt_similarity != 0.0f) { - int lcs_len = 0; - float similarity = 0; - - for (server_slot & slot : slots) { - // skip the slot if it is not available - if (slot.is_processing()) { - continue; - } - - // skip the slot if it does not contains cached tokens - if (slot.cache_tokens.empty()) { - continue; - } - - // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); - - // fraction of the common subsequence length compared to the current slot's prompt length - float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); - - // select the current slot if the criteria match - if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { - lcs_len = cur_lcs_len; - similarity = cur_similarity; - ret = &slot; - } - } - - if (ret != nullptr) { - SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity); - } - } - - // find the slot that has been least recently used - if (ret == nullptr) { - int64_t t_last = -1; - - for (server_slot & slot : slots) { - // skip the slot if it is not available - if (slot.is_processing()) { - continue; - } - - // select the current slot if the criteria match - if (!ret || slot.t_last_used <= t_last) { - t_last = slot.t_last_used; - ret = &slot; - } - } - - if (ret != nullptr) { - SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last); - } - } - - return ret; - } - - bool launch_slot_with_task(server_slot & slot, server_task && task) { - slot.reset(); - slot.id_task = task.id; - slot.index = task.index; - slot.task_type = task.type; - slot.params = std::move(task.params); - slot.prompt_tokens = std::move(task.prompt_tokens); - - if (!are_lora_equal(slot.params.lora, slot.lora)) { - // if lora is changed, we cannot reuse cached tokens - slot.cache_tokens.clear(); - slot.lora = slot.params.lora; - } - - if (!slot.prompt_tokens.validate(ctx)) { - send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); - return false; - } - SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); - - if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { - // Might be better to reject the request with a 400 ? - SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict); - slot.params.n_predict = slot.n_predict; - } - - if (slot.params.ignore_eos && has_eos_token) { - slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY}); - } - - { - if (slot.smpl != nullptr) { - common_sampler_free(slot.smpl); - } - - slot.smpl = common_sampler_init(model, slot.params.sampling); - if (slot.smpl == nullptr) { - // for now, the only error that may happen here is invalid grammar - send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); - return false; - } - } - - if (slot.ctx_dft) { - llama_batch_free(slot.batch_spec); - - slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); - } - - slot.state = SLOT_STATE_STARTED; - - SLT_INF(slot, "%s", "processing task\n"); - - return true; - } - - void kv_cache_clear() { - SRV_DBG("%s", "clearing KV cache\n"); - - // clear the entire KV cache - llama_memory_clear(llama_get_memory(ctx), true); - clean_kv_cache = false; - } - - bool process_token(completion_token_output & result, server_slot & slot) { - // remember which tokens were sampled - used for repetition penalties during sampling - const std::string token_str = result.text_to_send; - slot.sampled = result.tok; - - slot.generated_text += token_str; - if (slot.params.return_tokens) { - slot.generated_tokens.push_back(result.tok); - } - slot.has_next_token = true; - - // check if there is incomplete UTF-8 character at the end - bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); - - // search stop word and delete it - if (!incomplete) { - size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); - - const std::string str_test = slot.generated_text.substr(pos); - bool send_text = true; - - size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); - if (stop_pos != std::string::npos) { - slot.generated_text.erase( - slot.generated_text.begin() + pos + stop_pos, - slot.generated_text.end()); - pos = std::min(slot.n_sent_text, slot.generated_text.size()); - } else if (slot.has_next_token) { - stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); - send_text = stop_pos == std::string::npos; - } - - // check if there is any token to predict - if (send_text) { - // no send the stop word in the response - result.text_to_send = slot.generated_text.substr(pos, std::string::npos); - slot.n_sent_text += result.text_to_send.size(); - // add the token to slot queue and cache - } else { - result.text_to_send = ""; - } - - slot.add_token(result); - if (slot.params.stream) { - send_partial_response(slot, result); - } - } - - if (incomplete) { - slot.has_next_token = true; - } - - // if context shifting is disabled, make sure that we don't run out of context - if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx); - } - - // check the limits - if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); - } - - if (slot.has_new_line) { - // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent - if (slot.params.n_indent > 0) { - // check the current indentation - // TODO: improve by not doing it more than once for each new line - if (slot.last_nl_pos > 0) { - size_t pos = slot.last_nl_pos; - - int n_indent = 0; - while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { - n_indent++; - pos++; - } - - if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - // cut the last line - slot.generated_text.erase(pos, std::string::npos); - - SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); - } - } - - // find the next new line - { - const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); - - if (pos != std::string::npos) { - slot.last_nl_pos = pos + 1; - } - } - } - } - - // check if there is a new line in the generated text - if (result.text_to_send.find('\n') != std::string::npos) { - slot.has_new_line = true; - - // if we have seen a new line, we stop after a certain time limit, but only upon another new line - if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); - } - } - - // if context shift is disabled, we stop when it reaches the context limit - if (slot.n_past >= slot.n_ctx) { - slot.truncated = true; - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", - slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); - } - - if (llama_vocab_is_eog(vocab, result.tok)) { - slot.stop = STOP_TYPE_EOS; - slot.has_next_token = false; - - SLT_DBG(slot, "%s", "stopped by EOS\n"); - } - - const auto n_ctx_train = llama_model_n_ctx_train(model); - - if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { - slot.truncated = true; - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; // stop prediction - - SLT_WRN(slot, - "n_predict (%d) is set for infinite generation. " - "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", - slot.params.n_predict, n_ctx_train); - } - - SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); - - return slot.has_next_token; // continue - } - - void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { - size_t n_probs = slot.params.sampling.n_probs; - size_t n_vocab = llama_vocab_n_tokens(vocab); - if (post_sampling) { - const auto * cur_p = common_sampler_get_candidates(slot.smpl); - const size_t max_probs = cur_p->size; - - // set probability for sampled token - for (size_t i = 0; i < max_probs; i++) { - if (cur_p->data[i].id == result.tok) { - result.prob = cur_p->data[i].p; - break; - } - } - - // set probability for top n_probs tokens - result.probs.reserve(max_probs); - for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { - result.probs.push_back({ - cur_p->data[i].id, - common_token_to_piece(ctx, cur_p->data[i].id, special), - cur_p->data[i].p - }); - } - } else { - // TODO: optimize this with min-p optimization - std::vector cur = get_token_probabilities(ctx, idx); - - // set probability for sampled token - for (size_t i = 0; i < n_vocab; i++) { - // set probability for sampled token - if (cur[i].id == result.tok) { - result.prob = cur[i].p; - break; - } - } - - // set probability for top n_probs tokens - result.probs.reserve(n_probs); - for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { - result.probs.push_back({ - cur[i].id, - common_token_to_piece(ctx, cur[i].id, special), - cur[i].p - }); - } - } - } - - void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(task.id, error, type); - } - - void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(slot.id_task, error, type); - } - - void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); - - auto res = std::make_unique(); - res->id = id_task; - res->err_type = type; - res->err_msg = error; - - queue_results.send(std::move(res)); - } - - // if multimodal is enabled, send an error and return false - bool ensure_no_mtmd(const int id_task) { - if (mctx) { - send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); - return false; - } - return true; - } - - void send_partial_response(server_slot & slot, const completion_token_output & tkn) { - auto res = std::make_unique(); - - res->id = slot.id_task; - res->index = slot.index; - res->content = tkn.text_to_send; - res->tokens = { tkn.tok }; - - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; - res->post_sampling_probs = slot.params.post_sampling_probs; - - res->verbose = slot.params.verbose; - res->oaicompat = slot.params.oaicompat; - res->oaicompat_model = slot.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - - slot.update_chat_msg(res->oaicompat_msg_diffs); - - // populate res.probs_output - if (slot.params.sampling.n_probs > 0) { - res->prob_output = tkn; // copy the token probs - } - - // populate timings if this is final response or timings_per_token is enabled - if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) { - res->timings = slot.get_timings(); - } - - queue_results.send(std::move(res)); - } - - void send_final_response(server_slot & slot) { - auto res = std::make_unique(); - res->id = slot.id_task; - res->id_slot = slot.id; - - res->index = slot.index; - res->content = slot.generated_text; - res->tokens = std::move(slot.generated_tokens); - res->timings = slot.get_timings(); - res->prompt = slot.prompt_tokens.detokenize(ctx, true); - res->response_fields = std::move(slot.params.response_fields); - - res->truncated = slot.truncated; - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; - res->n_tokens_cached = slot.n_past; - res->has_new_line = slot.has_new_line; - res->stopping_word = slot.stopping_word; - res->stop = slot.stop; - res->post_sampling_probs = slot.params.post_sampling_probs; - - res->verbose = slot.params.verbose; - res->stream = slot.params.stream; - res->oaicompat = slot.params.oaicompat; - res->oaicompat_model = slot.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); - - // populate res.probs_output - if (slot.params.sampling.n_probs > 0) { - if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { - const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); - - size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); - res->probs_output = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end() - safe_offset); - } else { - res->probs_output = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end()); - } - } - - res->generation_params = slot.params; // copy the parameters - - queue_results.send(std::move(res)); - } - - void send_embedding(const server_slot & slot, const llama_batch & batch) { - auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; - res->n_tokens = slot.n_prompt_tokens; - res->oaicompat = slot.params.oaicompat; - - const int n_embd = llama_model_n_embd(model); - - std::vector embd_res(n_embd, 0.0f); - - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } - - if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - - res->embedding.push_back(std::vector(n_embd, 0.0f)); - continue; - } - - // normalize only when there is pooling - // TODO: configurable - if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, 2); - res->embedding.push_back(embd_res); - } else { - res->embedding.push_back({ embd, embd + n_embd }); - } - } - - SLT_DBG(slot, "%s", "sending embeddings\n"); - - queue_results.send(std::move(res)); - } - - void send_rerank(const server_slot & slot, const llama_batch & batch) { - auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; - res->n_tokens = slot.n_prompt_tokens; - - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } - - if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - - res->score = -1e6; - continue; - } - - res->score = embd[0]; - } - - SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); - - queue_results.send(std::move(res)); - } - - // - // Functions to create new task(s) and receive result(s) - // - - void cancel_tasks(const std::unordered_set & id_tasks) { - std::vector cancel_tasks; - cancel_tasks.reserve(id_tasks.size()); - for (const auto & id_task : id_tasks) { - SRV_WRN("cancel task, id_task = %d\n", id_task); - - server_task task(SERVER_TASK_TYPE_CANCEL); - task.id_target = id_task; - queue_results.remove_waiting_task_id(id_task); - cancel_tasks.push_back(std::move(task)); - } - // push to beginning of the queue, so it has highest priority - queue_tasks.post(std::move(cancel_tasks), true); - } - - // receive the results from task(s) - void receive_multi_results( - const std::unordered_set & id_tasks, - const std::function&)> & result_handler, - const std::function & error_handler, - const std::function & is_connection_closed) { - std::vector results(id_tasks.size()); - for (int i = 0; i < (int)id_tasks.size(); i++) { - server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); - - if (is_connection_closed()) { - cancel_tasks(id_tasks); - return; - } - - if (result == nullptr) { - i--; // retry - continue; - } - - if (result->is_error()) { - error_handler(result->to_json()); - cancel_tasks(id_tasks); - return; - } - - GGML_ASSERT( - dynamic_cast(result.get()) != nullptr - || dynamic_cast(result.get()) != nullptr - || dynamic_cast(result.get()) != nullptr - ); - const size_t idx = result->get_index(); - GGML_ASSERT(idx < results.size() && "index out of range"); - results[idx] = std::move(result); - } - result_handler(results); - } - - // receive the results from task(s), in stream mode - void receive_cmpl_results_stream( - const std::unordered_set & id_tasks, - const std::function & result_handler, - const std::function & error_handler, - const std::function & is_connection_closed) { - size_t n_finished = 0; - while (true) { - server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); - - if (is_connection_closed()) { - cancel_tasks(id_tasks); - return; - } - - if (result == nullptr) { - continue; // retry - } - - if (result->is_error()) { - error_handler(result->to_json()); - cancel_tasks(id_tasks); - return; - } - - GGML_ASSERT( - dynamic_cast(result.get()) != nullptr - || dynamic_cast(result.get()) != nullptr - ); - if (!result_handler(result)) { - cancel_tasks(id_tasks); - break; - } - - if (result->is_stop()) { - if (++n_finished == id_tasks.size()) { - break; - } - } - } - } - - // - // Functions to process the task - // - - void process_single_task(server_task && task) { - switch (task.type) { - case SERVER_TASK_TYPE_COMPLETION: - case SERVER_TASK_TYPE_INFILL: - case SERVER_TASK_TYPE_EMBEDDING: - case SERVER_TASK_TYPE_RERANK: - { - const int id_slot = task.id_selected_slot; - - server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); - - if (slot == nullptr) { - // if no slot is available, we defer this task for processing later - SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - if (!launch_slot_with_task(*slot, std::move(task))) { - SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); - break; - } - } break; - case SERVER_TASK_TYPE_CANCEL: - { - // release slot linked with the task id - for (auto & slot : slots) { - if (slot.id_task == task.id_target) { - slot.release(); - break; - } - } - } break; - case SERVER_TASK_TYPE_NEXT_RESPONSE: - { - // do nothing - } break; - case SERVER_TASK_TYPE_METRICS: - { - json slots_data = json::array(); - - int n_idle_slots = 0; - int n_processing_slots = 0; - - for (server_slot & slot : slots) { - json slot_data = slot.to_json(); - - if (slot.is_processing()) { - n_processing_slots++; - } else { - n_idle_slots++; - } - - slots_data.push_back(slot_data); - } - SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); - - auto res = std::make_unique(); - res->id = task.id; - res->slots_data = std::move(slots_data); - res->n_idle_slots = n_idle_slots; - res->n_processing_slots = n_processing_slots; - res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); - res->t_start = metrics.t_start; - - res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; - res->t_prompt_processing_total = metrics.t_prompt_processing_total; - res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; - res->t_tokens_generation_total = metrics.t_tokens_generation_total; - - res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; - res->t_prompt_processing = metrics.t_prompt_processing; - res->n_tokens_predicted = metrics.n_tokens_predicted; - res->t_tokens_generation = metrics.t_tokens_generation; - - res->n_decode_total = metrics.n_decode_total; - res->n_busy_slots_total = metrics.n_busy_slots_total; - - if (task.metrics_reset_bucket) { - metrics.reset_bucket(); - } - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_SAVE: - { - if (!ensure_no_mtmd(task.id)) { - break; - } - - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - const size_t token_count = slot->cache_tokens.size(); - const int64_t t_start = ggml_time_us(); - - std::string filename = task.slot_action.filename; - std::string filepath = task.slot_action.filepath; - - const llama_tokens & tokens = slot->cache_tokens.get_text_tokens(); - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); - - const int64_t t_end = ggml_time_us(); - const double t_save_ms = (t_end - t_start) / 1000.0; - - auto res = std::make_unique(); - res->id = task.id; - res->id_slot = id_slot; - res->filename = filename; - res->is_save = true; - res->n_tokens = token_count; - res->n_bytes = nwrite; - res->t_ms = t_save_ms; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_RESTORE: - { - if (!ensure_no_mtmd(task.id)) break; - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - const int64_t t_start = ggml_time_us(); - - std::string filename = task.slot_action.filename; - std::string filepath = task.slot_action.filepath; - - llama_tokens tokens; - tokens.resize(slot->n_ctx); - size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); - if (nread == 0) { - slot->cache_tokens.clear(); // KV may already been invalidated? - send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); - break; - } - tokens.resize(token_count); - slot->cache_tokens.clear(); - slot->cache_tokens.insert(tokens); - - const int64_t t_end = ggml_time_us(); - const double t_restore_ms = (t_end - t_start) / 1000.0; - - auto res = std::make_unique(); - res->id = task.id; - res->id_slot = id_slot; - res->filename = filename; - res->is_save = false; - res->n_tokens = token_count; - res->n_bytes = nread; - res->t_ms = t_restore_ms; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_ERASE: - { - if (!ensure_no_mtmd(task.id)) break; - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - // Erase token cache - const size_t n_erased = slot->cache_tokens.size(); - llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1); - slot->cache_tokens.clear(); - - auto res = std::make_unique(); - res->id = task.id; - res->id_slot = id_slot; - res->n_erased = n_erased; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SET_LORA: - { - params_base.lora_adapters = std::move(task.set_lora); - auto res = std::make_unique(); - res->id = task.id; - queue_results.send(std::move(res)); - } break; - - } - } - - void update_slots() { - // check if all slots are idle - { - bool all_idle = true; - - for (auto & slot : slots) { - if (slot.is_processing()) { - all_idle = false; - break; - } - } - - if (all_idle) { - SRV_INF("%s", "all slots are idle\n"); - if (clean_kv_cache) { - kv_cache_clear(); - } - - return; - } - } - - { - SRV_DBG("%s", "posting NEXT_RESPONSE\n"); - - server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); - task.id = queue_tasks.get_new_id(); - queue_tasks.post(std::move(task)); - } - - // apply context-shift if needed - // TODO: simplify and improve - for (server_slot & slot : slots) { - if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) { - if (!params_base.ctx_shift) { - // this check is redundant (for good) - // we should never get here, because generation should already stopped in process_token() - slot.release(); - send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); - continue; - } - - if (mctx) { - // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded - // we don't support ctx_shift because an image chunk may contains multiple tokens - GGML_ABORT("not supported by multimodal"); - } - - // Shift context - const int n_keep = slot.params.n_keep + add_bos_token; - const int n_left = slot.n_past - n_keep; - const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); - - SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.n_past, -n_discard); - - // add generated tokens to cache - { - llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy - for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { - new_tokens[i - n_discard] = new_tokens[i]; - } - - new_tokens.resize(slot.cache_tokens.size() - n_discard); - slot.cache_tokens.clear(); - slot.cache_tokens.insert(new_tokens); - } - - slot.n_past -= n_discard; - - slot.truncated = true; - } - } - - // start populating the batch for this iteration - common_batch_clear(batch); - - // track if given slot can be batched with slots already in the batch - server_slot * slot_batched = nullptr; - - auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end(); - }; - - // frist, add sampled tokens from any ongoing sequences - for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING) { - continue; - } - - // check if we can batch this slot with the previous one - if (!slot_batched) { - slot_batched = &slot; - } else if (!slot_batched->can_batch_with(slot)) { - continue; - } - - slot.i_batch = batch.n_tokens; - - common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); - - slot.n_past += 1; - slot.cache_tokens.push_back(slot.sampled); - - SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); - } - - // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); - int32_t n_ubatch = llama_n_ubatch(ctx); - - // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.n_tokens == 0) { - for (auto & slot : slots) { - // check if we can batch this slot with the previous one - if (slot.is_processing()) { - if (!slot_batched) { - slot_batched = &slot; - } else if (!slot_batched->can_batch_with(slot)) { - continue; - } - } - - // this slot still has a prompt to be processed - if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { - auto & prompt_tokens = slot.prompt_tokens; - - // TODO: maybe move branch to outside of this loop in the future - if (slot.state == SLOT_STATE_STARTED) { - slot.t_start_process_prompt = ggml_time_us(); - slot.t_start_generation = 0; - - slot.n_past = 0; - slot.n_prompt_tokens = prompt_tokens.size(); - slot.state = SLOT_STATE_PROCESSING_PROMPT; - - SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); - - // print prompt tokens (for debugging) - /*if (1) { - // first 16 tokens (avoid flooding logs) - for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); - } - } else { - // all - for (int i = 0; i < (int) prompt_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); - } - }*/ - - // empty prompt passed -> release the slot and send empty response - if (prompt_tokens.empty()) { - SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); - - slot.release(); - slot.print_timings(); - send_final_response(slot); - continue; - } - - // TODO: support memory-less logits computation - if (slot.need_logits() && !llama_get_memory(ctx)) { - slot.release(); - send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); - continue; - } - - if (!slot.can_split()) { - if (slot.n_prompt_tokens > n_ubatch) { - slot.release(); - send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); - continue; - } - - if (slot.n_prompt_tokens > slot.n_ctx) { - slot.release(); - send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER); - continue; - } - } else { - if (!params_base.ctx_shift) { - // if context shift is disabled, we make sure prompt size is smaller than KV size - // TODO: there should be a separate parameter that control prompt truncation - // context shift should be applied only during the generation phase - if (slot.n_prompt_tokens >= slot.n_ctx) { - slot.release(); - send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST); - continue; - } - } - if (slot.params.n_keep < 0) { - slot.params.n_keep = slot.n_prompt_tokens; - } - slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - - // if input prompt is too big, truncate it - if (slot.n_prompt_tokens >= slot.n_ctx) { - if (mctx) { - // we should never reach this - GGML_ABORT("not supported by multimodal"); - } - const int n_left = slot.n_ctx - slot.params.n_keep; - - const int n_block_size = n_left / 2; - const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - - const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens(); - llama_tokens new_tokens( - curr_tokens.begin(), - curr_tokens.begin() + slot.params.n_keep); - - new_tokens.insert( - new_tokens.end(), - curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, - curr_tokens.end()); - - prompt_tokens.clear(); - prompt_tokens.insert(new_tokens); - - slot.truncated = true; - slot.n_prompt_tokens = prompt_tokens.size(); - - SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); - - GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); - } - - if (slot.params.cache_prompt) { - // reuse any previously computed tokens that are common with the new prompt - slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); - - // reuse chunks from the cached prompt by shifting their KV cache in the new position - if (params_base.n_cache_reuse > 0) { - size_t head_c = slot.n_past; // cache - size_t head_p = slot.n_past; // current prompt - - if (mctx) { - // we should never reach this - GGML_ABORT("not supported by multimodal"); - } - - SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); - - while (head_c < slot.cache_tokens.size() && - head_p < prompt_tokens.size()) { - - size_t n_match = 0; - while (head_c + n_match < slot.cache_tokens.size() && - head_p + n_match < prompt_tokens.size() && - slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { - - n_match++; - } - - if (n_match >= (size_t) params_base.n_cache_reuse) { - SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); - //for (size_t i = head_p; i < head_p + n_match; i++) { - // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); - //} - - const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); - - for (size_t i = 0; i < n_match; i++) { - slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]); - slot.n_past++; - } - - head_c += n_match; - head_p += n_match; - } else { - head_c += 1; - } - } - - SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); - } - } else { - // if we don't cache the prompt, we have to remove the entire KV cache - slot.n_past = 0; - } - - if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - if (pos_min == -1) { - SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); - GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); - } - - const auto n_swa = llama_model_n_swa(model); - if (pos_min > std::max(0, slot.n_past - n_swa)) { - SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); - SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", - "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); - slot.n_past = 0; - } - } - } - - if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { - SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens); - - slot.n_past--; - } - - slot.n_prompt_tokens_processed = 0; - } - - if (!slot.can_split()) { - // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { - continue; - } - } - - // keep only the common part - if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) { - // could not partially delete (likely using a non-Transformer model) - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); - - // there is no common part left - slot.n_past = 0; - } - - SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); - - // remove the non-common part from the cache - slot.cache_tokens.keep_first(slot.n_past); - - // check if we should process the image - if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { - // process the image - int32_t new_n_past; - int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); - int32_t n_pos = new_n_past - slot.n_past; - - if (res != 0) { - SLT_ERR(slot, "failed to process image, res = %d\n", res); - slot.release(); - send_error(slot, "failed to process image", ERROR_TYPE_SERVER); - continue; - } - - // add the image chunk to cache - { - const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past); - slot.cache_tokens.push_back(chunk.get()); // copy - } - - slot.n_past += n_pos; - slot.n_prompt_tokens_processed += n_pos; - } - - // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { - // get next token to process - llama_token cur_tok = slot.prompt_tokens[slot.n_past]; - if (cur_tok == LLAMA_TOKEN_NULL) { - break; // end of text chunk - } - - // embedding requires all tokens in the batch to be output - const bool need_embd = server_task_type_need_embd(slot.task_type); - - common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); - slot.cache_tokens.push_back(cur_tok); - - slot.n_prompt_tokens_processed++; - slot.n_past++; - } - - // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); - - SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); - - // entire prompt has been processed - if (slot.n_past == slot.n_prompt_tokens) { - slot.state = SLOT_STATE_DONE_PROMPT; - - GGML_ASSERT(batch.n_tokens > 0); - GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size()); - - common_sampler_reset(slot.smpl); - - // Process all prompt tokens through sampler system - for (int i = 0; i < slot.n_prompt_tokens; ++i) { - llama_token id = slot.prompt_tokens[i]; - if (id != LLAMA_TOKEN_NULL) { - common_sampler_accept(slot.smpl, id, false); - } - } - - // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; - - slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; - - SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); - } - } - - if (batch.n_tokens >= n_batch) { - break; - } - } - } - - if (batch.n_tokens == 0) { - SRV_WRN("%s", "no tokens to decode\n"); - return; - } - - SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); - - if (slot_batched) { - // apply lora, only need to do it once per batch - common_set_adapter_lora(ctx, slot_batched->lora); - - llama_set_embeddings(ctx, slot_batched->need_embd()); - } - - int32_t i_next = 0; - - // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i = i_next) { - const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; - - const int ret = llama_decode(ctx, batch_view); - - metrics.on_decoded(slots); - - if (ret != 0) { - { - std::string err; - - if (n_batch == 1 && ret == 1) { - err = "Context size has been exceeded."; - } - - if (ret == -1) { - err = "Invalid input batch."; - } - - if (ret < -1) { - // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() - err = "Compute error."; - } - - // TODO: handle ret == 2 (abort) when we start aborting - - if (!err.empty()) { - SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); - for (auto & slot : slots) { - slot.release(); - send_error(slot, err); - } - break; - } - } - - // retry with half the batch size to try to find a free slot in the KV cache - n_batch /= 2; - - SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); - - continue; // continue loop of n_batch - } - - // move the head of the batch forward with the number of tokens we just processed - i_next = i + n_tokens; - - // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx); - - for (auto & slot : slots) { - if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { - continue; // continue loop of slots - } - - if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { - // prompt evaluated for embedding - send_embedding(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - if (slot.task_type == SERVER_TASK_TYPE_RERANK) { - send_rerank(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - // prompt evaluated for next-token prediction - slot.state = SLOT_STATE_GENERATING; - } else if (slot.state != SLOT_STATE_GENERATING) { - continue; // continue loop of slots - } - - const int tok_idx = slot.i_batch - i; - - llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); - - slot.i_batch = -1; - - common_sampler_accept(slot.smpl, id, true); - - slot.n_decoded += 1; - - const int64_t t_current = ggml_time_us(); - - if (slot.n_decoded == 1) { - slot.t_start_generation = t_current; - slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; - metrics.on_prompt_eval(slot); - } - - slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3; - - completion_token_output result; - result.tok = id; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - - if (slot.params.sampling.n_probs > 0) { - populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx); - } - - if (!process_token(result, slot)) { - // release slot because of stop condition - slot.release(); - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - continue; - } - } - - // do speculative decoding - for (auto & slot : slots) { - if (!slot.is_processing() || !slot.can_speculate()) { - continue; - } - - if (slot.state != SLOT_STATE_GENERATING) { - continue; - } - - if (mctx) { - // we should never reach this, as speculative is automatically disabled if mmproj is loaded - GGML_ABORT("not supported by multimodal"); - } - - // determine the max draft that fits the current slot state - int n_draft_max = slot.params.speculative.n_max; - - // note: n_past is not yet increased for the `id` token sampled above - // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); - - if (slot.n_remaining > 0) { - n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); - } - - SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); - - if (n_draft_max < slot.params.speculative.n_min) { - SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min); - - continue; - } - - llama_token id = slot.sampled; - - struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; - params_spec.p_min = slot.params.speculative.p_min; - - const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); - - // ignore small drafts - if (slot.params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); - - continue; - } - - // keep track of total number of drafted tokens tested - slot.n_draft_total += draft.size(); - - // construct the speculation batch - common_batch_clear(slot.batch_spec); - common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true); - - for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); - } - - SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); - - llama_decode(ctx, slot.batch_spec); - - // the accepted tokens from the speculation - const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); - - slot.n_past += ids.size(); - slot.n_decoded += ids.size(); - - // update how many tokens out of those tested were accepted - slot.n_draft_accepted += ids.size() - 1; - - slot.cache_tokens.push_back(id); - slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); - - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1); - - for (size_t i = 0; i < ids.size(); ++i) { - completion_token_output result; - - result.tok = ids[i]; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // set later - - // TODO: set result.probs - - if (!process_token(result, slot)) { - // release slot because of stop condition - slot.release(); - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - break; - } - } - - SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past); - } - } - - SRV_DBG("%s", "run slots completed\n"); - } - - json model_meta() const { - return json { - {"vocab_type", llama_vocab_type (vocab)}, - {"n_vocab", llama_vocab_n_tokens (vocab)}, - {"n_ctx_train", llama_model_n_ctx_train(model)}, - {"n_embd", llama_model_n_embd (model)}, - {"n_params", llama_model_n_params (model)}, - {"size", llama_model_size (model)}, - }; - } -}; static void log_server_request(const httplib::Request & req, const httplib::Response & res) { // skip GH copilot requests when using default port diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 2ef9a164513c3..f1103595e554d 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -1,13 +1,3 @@ -#pragma once - -#include "common.h" -#include "log.h" -#include "llama.h" -#include "arg.h" // common_remote_get_content -#include "base64.hpp" -#include "mtmd.h" -#include "mtmd-helper.h" -#include "chat.h" // increase max payload length to allow use of larger context size #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 @@ -15,506 +5,6 @@ #define CPPHTTPLIB_TCP_NODELAY true #include -#define JSON_ASSERT GGML_ASSERT -#include - -#include -#include -#include -#include -#include -#include - -#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" - -using json = nlohmann::ordered_json; - -#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) - -#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) - -#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) - -using raw_buffer = std::vector; - -template -static T json_value(const json & body, const std::string & key, const T & default_value) { - // Fallback null to default value - if (body.contains(key) && !body.at(key).is_null()) { - try { - return body.at(key); - } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) { - LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), json(default_value).type_name()); - return default_value; - } - } else { - return default_value; - } -} - -const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); - -// thin wrapper around common_grammar_trigger with (de)serialization functions -struct server_grammar_trigger { - common_grammar_trigger value; - - server_grammar_trigger() = default; - server_grammar_trigger(const common_grammar_trigger & value) : value(value) {} - server_grammar_trigger(const json & in) { - value.type = (common_grammar_trigger_type) in.at("type").get(); - value.value = in.at("value").get(); - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - value.token = (llama_token) in.at("token").get(); - } - } - - json to_json() const { - json out { - {"type", (int) value.type}, - {"value", value.value}, - }; - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - out["token"] = (int) value.token; - } - return out; - } -}; - -// -// tokenizer and input processing utils -// - -static bool json_is_array_of_numbers(const json & data) { - if (data.is_array()) { - for (const auto & e : data) { - if (!e.is_number_integer()) { - return false; - } - } - return true; - } - return false; -} - -// is array having BOTH numbers & strings? -static bool json_is_array_of_mixed_numbers_strings(const json & data) { - bool seen_string = false; - bool seen_number = false; - if (data.is_array()) { - for (const auto & e : data) { - seen_string |= e.is_string(); - seen_number |= e.is_number_integer(); - if (seen_number && seen_string) { - return true; - } - } - } - return false; -} - -// get value by path(key1 / key2) -static json json_get_nested_values(const std::vector & paths, const json & js) { - json result = json::object(); - - for (const std::string & path : paths) { - json current = js; - const auto keys = string_split(path, /*separator*/ '/'); - bool valid_path = true; - for (const std::string & k : keys) { - if (valid_path && current.is_object() && current.contains(k)) { - current = current[k]; - } else { - valid_path = false; - } - } - if (valid_path) { - result[path] = current; - } - } - return result; -} - -/** - * this handles 2 cases: - * - only string, example: "string" - * - mixed string and tokens, example: [12, 34, "string", 56, 78] - */ -static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { - // If `add_bos` is true, we only add BOS, when json_prompt is a string, - // or the first element of the json_prompt array is a string. - llama_tokens prompt_tokens; - - if (json_prompt.is_array()) { - bool first = true; - for (const auto & p : json_prompt) { - if (p.is_string()) { - auto s = p.template get(); - - llama_tokens p; - if (first) { - p = common_tokenize(vocab, s, add_special, parse_special); - first = false; - } else { - p = common_tokenize(vocab, s, false, parse_special); - } - - prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } else { - if (first) { - first = false; - } - - prompt_tokens.push_back(p.template get()); - } - } - } else { - auto s = json_prompt.template get(); - prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); - } - - return prompt_tokens; -} - -/** - * break the input "prompt" object into multiple prompt if needed, then tokenize them - * this supports these cases: - * - "prompt": "string" - * - "prompt": [12, 34, 56] - * - "prompt": [12, 34, "string", 56, 78] - * and multiple prompts (multi-tasks): - * - "prompt": ["string1", "string2"] - * - "prompt": ["string1", [12, 34, 56]] - * - "prompt": [[12, 34, 56], [78, 90, 12]] - * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]] - */ -static std::vector tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { - std::vector result; - if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { - // string or mixed - result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special)); - } else if (json_is_array_of_numbers(json_prompt)) { - // array of tokens - result.push_back(json_prompt.get()); - } else if (json_prompt.is_array()) { - // array of prompts - result.reserve(json_prompt.size()); - for (const auto & p : json_prompt) { - if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) { - result.push_back(tokenize_mixed(vocab, p, add_special, parse_special)); - } else if (json_is_array_of_numbers(p)) { - // array of tokens - result.push_back(p.get()); - } else { - throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens"); - } - } - } else { - throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts"); - } - if (result.empty()) { - throw std::runtime_error("\"prompt\" must not be empty"); - } - return result; -} - -// return the last index of character that can form a valid string -// if the last character is potentially cut in half, return the index before the cut -// if validate_utf8(text) == text.size(), then the whole text is valid utf8 -static size_t validate_utf8(const std::string& text) { - size_t len = text.size(); - if (len == 0) return 0; - - // Check the last few bytes to see if a multi-byte character is cut off - for (size_t i = 1; i <= 4 && i <= len; ++i) { - unsigned char c = text[len - i]; - // Check for start of a multi-byte sequence from the end - if ((c & 0xE0) == 0xC0) { - // 2-byte character start: 110xxxxx - // Needs at least 2 bytes - if (i < 2) return len - i; - } else if ((c & 0xF0) == 0xE0) { - // 3-byte character start: 1110xxxx - // Needs at least 3 bytes - if (i < 3) return len - i; - } else if ((c & 0xF8) == 0xF0) { - // 4-byte character start: 11110xxx - // Needs at least 4 bytes - if (i < 4) return len - i; - } - } - - // If no cut-off multi-byte character is found, return full length - return len; -} - -// -// template utils -// - -// format rerank task: [BOS]query[EOS][SEP]doc[EOS] -static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) { - llama_tokens result; - - // Get EOS token - use SEP token as fallback if EOS is not available - llama_token eos_token = llama_vocab_eos(vocab); - if (eos_token == LLAMA_TOKEN_NULL) { - eos_token = llama_vocab_sep(vocab); - } - - result.reserve(doc.size() + query.size() + 4); - if (llama_vocab_get_add_bos(vocab)) { - result.push_back(llama_vocab_bos(vocab)); - } - result.insert(result.end(), query.begin(), query.end()); - if (llama_vocab_get_add_eos(vocab)) { - result.push_back(eos_token); - } - if (llama_vocab_get_add_sep(vocab)) { - result.push_back(llama_vocab_sep(vocab)); - } - result.insert(result.end(), doc.begin(), doc.end()); - if (llama_vocab_get_add_eos(vocab)) { - result.push_back(eos_token); - } - - return result; -} - -// format infill task -static llama_tokens format_infill( - const llama_vocab * vocab, - const json & input_prefix, - const json & input_suffix, - const json & input_extra, - const int n_batch, - const int n_predict, - const int n_ctx, - const bool spm_infill, - const llama_tokens & tokens_prompt - ) { - // TODO: optimize this block by reducing memory allocations and movement - - // use FIM repo-level pattern: - // ref: https://arxiv.org/pdf/2409.12186 - // - // [FIM_REP]myproject - // [FIM_SEP]filename0 - // extra chunk 0 - // [FIM_SEP]filename1 - // extra chunk 1 - // ... - // [FIM_SEP]filename - // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt - // - llama_tokens extra_tokens; - extra_tokens.reserve(n_ctx); - - auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); - auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); - - if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { - // TODO: make project name an input - static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); - - extra_tokens.push_back(llama_vocab_fim_rep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); - } - for (const auto & chunk : input_extra) { - // { "text": string, "filename": string } - const std::string text = json_value(chunk, "text", std::string()); - const std::string filename = json_value(chunk, "filename", std::string("tmp")); - - if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { - const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); - - extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); - } else { - // chunk separator in binary form to avoid confusing the AI - static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; - static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); - - extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); - } - - const auto chunk_tokens = common_tokenize(vocab, text, false, false); - extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); - } - - if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { - // TODO: current filename - static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); - - extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); - } - - // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) - const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); - const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); - - SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); - - // fill the rest of the context with extra chunks - const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); - - tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); - tokens_suffix.resize(n_suffix_take); - - tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); - tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); - tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); - - auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; - auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; - - if (llama_vocab_get_add_bos(vocab)) { - embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); - } - - SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); - - // put the extra context before the FIM prefix - embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); - - embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - embd_inp.push_back(llama_vocab_fim_mid(vocab)); - - return embd_inp; -} - -// -// base64 utils (TODO: move to common in the future) -// - -static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; - -static inline bool is_base64(uint8_t c) { - return (isalnum(c) || (c == '+') || (c == '/')); -} - -static inline raw_buffer base64_decode(const std::string & encoded_string) { - int i = 0; - int j = 0; - int in_ = 0; - - int in_len = encoded_string.size(); - - uint8_t char_array_4[4]; - uint8_t char_array_3[3]; - - raw_buffer ret; - - while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { - char_array_4[i++] = encoded_string[in_]; in_++; - if (i == 4) { - for (i = 0; i < 4; i++) { - char_array_4[i] = base64_chars.find(char_array_4[i]); - } - - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (i = 0; (i < 3); i++) { - ret.push_back(char_array_3[i]); - } - - i = 0; - } - } - - if (i) { - for (j = i; j < 4; j++) { - char_array_4[j] = 0; - } - - for (j = 0; j < 4; j++) { - char_array_4[j] = base64_chars.find(char_array_4[j]); - } - - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (j = 0; j < i - 1; j++) { - ret.push_back(char_array_3[j]); - } - } - - return ret; -} - -// -// random string / id -// - -static std::string random_string() { - static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); - - std::random_device rd; - std::mt19937 generator(rd()); - - std::string result(32, ' '); - - for (int i = 0; i < 32; ++i) { - result[i] = str[generator() % str.size()]; - } - - return result; -} - -static std::string gen_chatcmplid() { - return "chatcmpl-" + random_string(); -} - -static std::string gen_tool_call_id() { - return random_string(); -} - -// -// other common utils -// - -// TODO: reuse llama_detokenize -template -static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { - std::string ret; - for (; begin != end; ++begin) { - ret += common_token_to_piece(ctx, *begin); - } - - return ret; -} - -// format incomplete utf-8 multibyte character for output -static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { - std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); - - // if the size is 1 and first bit is 1, meaning it's a partial character - // (size > 1 meaning it's already a known token) - if (out.size() == 1 && (out[0] & 0x80) == 0x80) { - std::stringstream ss; - ss << std::hex << (out[0] & 0xff); - std::string res(ss.str()); - out = "byte: \\x" + res; - } - - return out; -} - static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) { const std::string str = std::string(event) + ": " + @@ -526,817 +16,6 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons return sink.write(str.c_str(), str.size()); } -// -// OAI utils -// - -// used by /completions endpoint -static json oaicompat_completion_params_parse(const json & body) { - json llama_params; - - if (!body.contains("prompt")) { - throw std::runtime_error("\"prompt\" is required"); - } - - // Handle "stop" field - if (body.contains("stop") && body.at("stop").is_string()) { - llama_params["stop"] = json::array({body.at("stop").get()}); - } else { - llama_params["stop"] = json_value(body, "stop", json::array()); - } - - // Handle "n" field - int n_choices = json_value(body, "n", 1); - if (n_choices != 1) { - throw std::runtime_error("Only one completion choice is allowed"); - } - - // Handle "echo" field - if (json_value(body, "echo", false)) { - throw std::runtime_error("Only no echo is supported"); - } - - // Params supported by OAI but unsupported by llama.cpp - static const std::vector unsupported_params { "best_of", "suffix" }; - for (const auto & param : unsupported_params) { - if (body.contains(param)) { - throw std::runtime_error("Unsupported param: " + param); - } - } - - // Copy remaining properties to llama_params - for (const auto & item : body.items()) { - // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" - if (!llama_params.contains(item.key()) || item.key() == "n_predict") { - llama_params[item.key()] = item.value(); - } - } - - return llama_params; -} - -struct oaicompat_parser_options { - bool use_jinja; - bool prefill_assistant; - common_reasoning_format reasoning_format; - std::map chat_template_kwargs; - common_chat_templates * tmpls; - bool allow_image; - bool allow_audio; - bool enable_thinking = true; -}; - -// used by /chat/completions endpoint -static json oaicompat_chat_params_parse( - json & body, /* openai api json semantics */ - const oaicompat_parser_options & opt, - std::vector & out_files) -{ - json llama_params; - - auto tools = json_value(body, "tools", json()); - auto has_tools = tools.is_array() && !tools.empty(); - auto stream = json_value(body, "stream", false); - auto tool_choice = json_value(body, "tool_choice", std::string("auto")); - - if (!opt.use_jinja) { - if (has_tools) { - throw std::runtime_error("tools param requires --jinja flag"); - } - if (tool_choice != "auto") { - throw std::runtime_error("tool_choice param requires --jinja flag"); - } - } - - // Handle "stop" field - if (body.contains("stop") && body.at("stop").is_string()) { - llama_params["stop"] = json::array({body.at("stop").get()}); - } else { - llama_params["stop"] = json_value(body, "stop", json::array()); - } - - auto json_schema = json_value(body, "json_schema", json()); - auto grammar = json_value(body, "grammar", std::string()); - if (!json_schema.is_null() && !grammar.empty()) { - throw std::runtime_error("Cannot use both json_schema and grammar"); - } - - // Handle "response_format" field - if (body.contains("response_format")) { - json response_format = json_value(body, "response_format", json::object()); - std::string response_type = json_value(response_format, "type", std::string()); - if (response_type == "json_object") { - json_schema = json_value(response_format, "schema", json::object()); - } else if (response_type == "json_schema") { - auto schema_wrapper = json_value(response_format, "json_schema", json::object()); - json_schema = json_value(schema_wrapper, "schema", json::object()); - } else if (!response_type.empty() && response_type != "text") { - throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); - } - } - - // get input files - if (!body.contains("messages")) { - throw std::runtime_error("'messages' is required"); - } - json & messages = body.at("messages"); - if (!messages.is_array()) { - throw std::runtime_error("Expected 'messages' to be an array"); - } - for (auto & msg : messages) { - std::string role = json_value(msg, "role", std::string()); - if (role != "assistant" && !msg.contains("content")) { - throw std::runtime_error("All non-assistant messages must contain 'content'"); - } - if (role == "assistant") { - if (!msg.contains("content") && !msg.contains("tool_calls")) { - throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!"); - } - if (!msg.contains("content")) { - continue; // avoid errors with no content - } - } - json & content = msg.at("content"); - if (content.is_string() || content.is_null()) { - continue; - } - - if (!content.is_array()) { - throw std::runtime_error("Expected 'content' to be a string or an array"); - } - - for (auto & p : content) { - std::string type = json_value(p, "type", std::string()); - if (type == "image_url") { - if (!opt.allow_image) { - throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); - } - - json image_url = json_value(p, "image_url", json::object()); - std::string url = json_value(image_url, "url", std::string()); - if (string_starts_with(url, "http")) { - // download remote image - // TODO @ngxson : maybe make these params configurable - common_remote_params params; - params.headers.push_back("User-Agent: llama.cpp/" + build_info); - params.max_size = 1024 * 1024 * 10; // 10MB - params.timeout = 10; // seconds - SRV_INF("downloading image from '%s'\n", url.c_str()); - auto res = common_remote_get_content(url, params); - if (200 <= res.first && res.first < 300) { - SRV_INF("downloaded %ld bytes\n", res.second.size()); - raw_buffer data; - data.insert(data.end(), res.second.begin(), res.second.end()); - out_files.push_back(data); - } else { - throw std::runtime_error("Failed to download image"); - } - - } else { - // try to decode base64 image - std::vector parts = string_split(url, /*separator*/ ','); - if (parts.size() != 2) { - throw std::runtime_error("Invalid image_url.url value"); - } else if (!string_starts_with(parts[0], "data:image/")) { - throw std::runtime_error("Invalid image_url.url format: " + parts[0]); - } else if (!string_ends_with(parts[0], "base64")) { - throw std::runtime_error("image_url.url must be base64 encoded"); - } else { - auto base64_data = parts[1]; - auto decoded_data = base64_decode(base64_data); - out_files.push_back(decoded_data); - } - } - - // replace this chunk with a marker - p["type"] = "text"; - p["text"] = mtmd_default_marker(); - p.erase("image_url"); - - } else if (type == "input_audio") { - if (!opt.allow_audio) { - throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); - } - - json input_audio = json_value(p, "input_audio", json::object()); - std::string data = json_value(input_audio, "data", std::string()); - std::string format = json_value(input_audio, "format", std::string()); - // while we also support flac, we don't allow it here so we matches the OAI spec - if (format != "wav" && format != "mp3") { - throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'"); - } - auto decoded_data = base64_decode(data); // expected to be base64 encoded - out_files.push_back(decoded_data); - - // replace this chunk with a marker - p["type"] = "text"; - p["text"] = mtmd_default_marker(); - p.erase("input_audio"); - - } else if (type != "text") { - throw std::runtime_error("unsupported content[].type"); - } - } - } - - common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(messages); - inputs.tools = common_chat_tools_parse_oaicompat(tools); - inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice); - inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); - inputs.grammar = grammar; - inputs.use_jinja = opt.use_jinja; - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); - inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); - inputs.reasoning_format = opt.reasoning_format; - inputs.enable_thinking = opt.enable_thinking; - if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { - if (body.contains("grammar")) { - throw std::runtime_error("Cannot use custom grammar constraints with tools."); - } - llama_params["parse_tool_calls"] = true; - } - - // merge the template args provided from command line with the args provided in the user request - auto chat_template_kwargs_object = json_value(body, "chat_template_kwargs", json::object()); - inputs.chat_template_kwargs = opt.chat_template_kwargs; - for (const auto & item : chat_template_kwargs_object.items()) { - inputs.chat_template_kwargs[item.key()] = item.value().dump(); - } - - // if the assistant message appears at the end of list, we do not add end-of-turn token - // for ex. this can be useful to modify the reasoning process in reasoning models - bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant; - common_chat_msg last_message; - if (prefill_assistant_message) { - last_message = inputs.messages.back(); - inputs.messages.pop_back(); - - /* sanity check, max one assistant message at the end of the list */ - if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){ - throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list."); - } - - /* TODO: test this properly */ - inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE; - - if ( (!inputs.enable_thinking) || inputs.chat_template_kwargs.find("enable_thinking") != inputs.chat_template_kwargs.end()) { - throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking."); - } - - inputs.add_generation_prompt = true; - } - - // Apply chat template to the list of messages - auto chat_params = common_chat_templates_apply(opt.tmpls, inputs); - - /* Append assistant prefilled message */ - if (prefill_assistant_message) { - chat_params.prompt += last_message.content; - } - - llama_params["chat_format"] = static_cast(chat_params.format); - llama_params["prompt"] = chat_params.prompt; - if (!chat_params.grammar.empty()) { - llama_params["grammar"] = chat_params.grammar; - } - llama_params["grammar_lazy"] = chat_params.grammar_lazy; - auto grammar_triggers = json::array(); - for (const auto & trigger : chat_params.grammar_triggers) { - server_grammar_trigger ct(trigger); - grammar_triggers.push_back(ct.to_json()); - } - llama_params["grammar_triggers"] = grammar_triggers; - llama_params["preserved_tokens"] = chat_params.preserved_tokens; - llama_params["thinking_forced_open"] = chat_params.thinking_forced_open; - for (const auto & stop : chat_params.additional_stops) { - llama_params["stop"].push_back(stop); - } - - // Handle "n" field - int n_choices = json_value(body, "n", 1); - if (n_choices != 1) { - throw std::runtime_error("Only one completion choice is allowed"); - } - - // Handle "logprobs" field - // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future - if (json_value(body, "logprobs", false)) { - if (has_tools && stream) { - throw std::runtime_error("logprobs is not supported with tools + stream"); - } - llama_params["n_probs"] = json_value(body, "top_logprobs", 20); - } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { - throw std::runtime_error("top_logprobs requires logprobs to be set to true"); - } - - // Copy remaining properties to llama_params - // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint. - // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp - for (const auto & item : body.items()) { - // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" - if (!llama_params.contains(item.key()) || item.key() == "n_predict") { - llama_params[item.key()] = item.value(); - } - } - - return llama_params; -} - -static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) { - json data = json::array(); - int32_t n_tokens = 0; - int i = 0; - for (const auto & elem : embeddings) { - json embedding_obj; - - if (use_base64) { - const auto& vec = json_value(elem, "embedding", json::array()).get>(); - const char* data_ptr = reinterpret_cast(vec.data()); - size_t data_size = vec.size() * sizeof(float); - embedding_obj = { - {"embedding", base64::encode(data_ptr, data_size)}, - {"index", i++}, - {"object", "embedding"}, - {"encoding_format", "base64"} - }; - } else { - embedding_obj = { - {"embedding", json_value(elem, "embedding", json::array())}, - {"index", i++}, - {"object", "embedding"} - }; - } - data.push_back(embedding_obj); - - n_tokens += json_value(elem, "tokens_evaluated", 0); - } - - json res = json { - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", "list"}, - {"usage", json { - {"prompt_tokens", n_tokens}, - {"total_tokens", n_tokens} - }}, - {"data", data} - }; - - return res; -} - -static json format_response_rerank( - const json & request, - const json & ranks, - bool is_tei_format, - std::vector & texts) { - json res; - if (is_tei_format) { - // TEI response format - res = json::array(); - bool return_text = json_value(request, "return_text", false); - for (const auto & rank : ranks) { - int index = json_value(rank, "index", 0); - json elem = json{ - {"index", index}, - {"score", json_value(rank, "score", 0.0)}, - }; - if (return_text) { - elem["text"] = std::move(texts[index]); - } - res.push_back(elem); - } - } else { - // Jina response format - json results = json::array(); - int32_t n_tokens = 0; - for (const auto & rank : ranks) { - results.push_back(json{ - {"index", json_value(rank, "index", 0)}, - {"relevance_score", json_value(rank, "score", 0.0)}, - }); - - n_tokens += json_value(rank, "tokens_evaluated", 0); - } - - res = json{ - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", "list"}, - {"usage", json{ - {"prompt_tokens", n_tokens}, - {"total_tokens", n_tokens} - }}, - {"results", results} - }; - } - - return res; -} - -static bool is_valid_utf8(const std::string & str) { - const unsigned char* bytes = reinterpret_cast(str.data()); - const unsigned char* end = bytes + str.length(); - - while (bytes < end) { - if (*bytes <= 0x7F) { - // 1-byte sequence (0xxxxxxx) - bytes++; - } else if ((*bytes & 0xE0) == 0xC0) { - // 2-byte sequence (110xxxxx 10xxxxxx) - if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) - return false; - bytes += 2; - } else if ((*bytes & 0xF0) == 0xE0) { - // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) - if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) - return false; - bytes += 3; - } else if ((*bytes & 0xF8) == 0xF0) { - // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || - (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) - return false; - bytes += 4; - } else { - // Invalid UTF-8 lead byte - return false; - } - } - - return true; -} - -static json format_tokenizer_response(const json & tokens) { - return json { - {"tokens", tokens} - }; -} - -static json format_detokenized_response(const std::string & content) { - return json { - {"content", content} - }; -} - -static json format_logit_bias(const std::vector & logit_bias) { - json data = json::array(); - for (const auto & lb : logit_bias) { - data.push_back(json{ - {"bias", lb.bias}, - {"token", lb.token}, - }); - } - return data; -} - -static std::string safe_json_to_str(const json & data) { - return data.dump(-1, ' ', false, json::error_handler_t::replace); -} - -static std::vector get_token_probabilities(llama_context * ctx, int idx) { - std::vector cur; - const auto * logits = llama_get_logits_ith(ctx, idx); - - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - - const int n_vocab = llama_vocab_n_tokens(vocab); - - cur.resize(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; - } - - // sort tokens by logits - std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); - - // apply softmax - float max_l = cur[0].logit; - float cum_sum = 0.0f; - for (size_t i = 0; i < cur.size(); ++i) { - float p = expf(cur[i].logit - max_l); - cur[i].p = p; - cum_sum += p; - } - for (size_t i = 0; i < cur.size(); ++i) { - cur[i].p /= cum_sum; - } - - return cur; -} - -static bool are_lora_equal( - const std::vector & l1, - const std::vector & l2) { - if (l1.size() != l2.size()) { - return false; - } - for (size_t i = 0; i < l1.size(); ++i) { - // we don't check lora.path to reduce the time complexity - if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { - return false; - } - } - return true; -} - -// parse lora config from JSON request, returned a copy of lora_base with updated scale -static std::vector parse_lora_request( - const std::vector & lora_base, - const json & data) { - std::vector lora(lora_base); - int max_idx = lora.size(); - - // clear existing value - for (auto & entry : lora) { - entry.scale = 0.0f; - } - - // set value - for (const auto & entry : data) { - int id = json_value(entry, "id", -1); - float scale = json_value(entry, "scale", 0.0f); - if (0 <= id && id < max_idx) { - lora[id].scale = scale; - } else { - throw std::runtime_error("invalid adapter id"); - } - } - - return lora; -} - -// -// utils for interacting with libmtmd -// (may need to refactor in near future) -// - -/** - * server_tokens is a helper to manage the input tokens and image for the server. - * it is made this way to simplify the logic of KV cache management. - */ -struct server_tokens { - bool has_mtmd = false; - -private: // disallow accessing these members directly, risking out-of-sync - - // map a **start** position in tokens to the image chunk - std::unordered_map map_pos_to_media; - - // list of tokens - // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token - // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position** - // important: for models using mrope, an image can contain multiple tokens but will use only one **position** - llama_tokens tokens; - - // for ex. with input of 5 text tokens and 2 images: - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] - // pos 0 1 2 3 4 5 6 7 8 9 - // map_pos_to_media will contain: {5, img0}, {8, img1} - -public: - server_tokens() = default; - ~server_tokens() = default; - - // Prevent copying - server_tokens(const server_tokens&) = delete; - server_tokens& operator=(const server_tokens&) = delete; - - // Allow moving (usually implicitly generated if members are movable) - server_tokens(server_tokens&&) = default; - server_tokens& operator=(server_tokens&&) = default; - - // Allow accessing elements using [] operator - llama_token operator[](size_t index) { return tokens[index]; } - const llama_token& operator[](size_t index) const { return tokens[index]; } - - server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { - for (size_t i = 0; i < mtmd_chunks.size(); ++i) { - push_back(mtmd_chunks[i]); - } - } - - server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} - - // for debugging - std::string str() const { - std::ostringstream oss; - oss << "tokens: "; - for (const auto & t : tokens) { - if (t == LLAMA_TOKEN_NULL) { - oss << " "; - } else { - oss << t << " "; - } - } - oss << "\n"; - oss << "image pos: "; - for (const auto & it : map_pos_to_media) { - oss << it.first << ", "; - } - return oss.str(); - } - - const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const { - auto it = map_pos_to_media.find(pos); - if (it != map_pos_to_media.end()) { - return it->second; - } else { - throw std::runtime_error("Chunk not found"); - } - } - - void push_back(llama_token tok) { - if (tok == LLAMA_TOKEN_NULL) { - throw std::runtime_error("Invalid token"); - } - tokens.emplace_back(tok); - } - - // will create a copy of the chunk if it contains non-text data - void push_back(const mtmd_input_chunk * chunk) { - auto type = mtmd_input_chunk_get_type(chunk); - if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { - GGML_ASSERT(has_mtmd); - const int n_pos = mtmd_input_chunk_get_n_pos(chunk); - llama_pos start_pos = tokens.size(); - for (int i = 0; i < n_pos; ++i) { - tokens.emplace_back(LLAMA_TOKEN_NULL); - } - mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_pos_to_media[start_pos] = std::move(new_chunk); - } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - size_t n_tokens; - auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); - for (size_t i = 0; i < n_tokens; ++i) { - push_back(text_tokens[i]); - } - } else { - GGML_ABORT("Invalid chunk type"); - } - } - - // for compatibility with context shift and prompt truncation - void insert(const llama_tokens & inp_tokens) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); - } - - // for compatibility with speculative decoding, ctx shift, slot save/load - const llama_tokens & get_text_tokens() const { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - return tokens; - } - - // for compatibility with speculative decoding - void set_token(llama_pos pos, llama_token id) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens[pos] = id; - } - - size_t size() const { - return tokens.size(); - } - - bool empty() const { - return tokens.empty(); - } - - void clear() { - tokens.clear(); - } - - void keep_first(size_t n) { - GGML_ASSERT(n <= tokens.size()); - if (has_mtmd) { - if (n == tokens.size()) { - return; // nothing to do - } - // we throw an error if we try to remove a token in the middle of an image - // for ex. with input of 5 text tokens and 2 images: - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] - // n 1 2 3 4 5 6 7 8 9 10 - // allowed to resize ^ ^ - // disallowed to resize ^ ^ ^ - if (n > 0) { - llama_token last_token = tokens[n - 1]; - // make sure we never remove tokens in the middle of an image - if (last_token == LLAMA_TOKEN_NULL) { - find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk - } - } - // remove all image chunks that are not used anymore - for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) { - llama_pos pos = it->first; - if (pos >= (llama_pos)n) { - it = map_pos_to_media.erase(it); - } else { - ++it; - } - } - } - tokens.resize(n); - } - - std::string detokenize(const llama_context * ctx, bool special) const { - llama_tokens text_tokens; - text_tokens.reserve(tokens.size()); - for (const auto & t : tokens) { - if (t != LLAMA_TOKEN_NULL) { - text_tokens.push_back(t); - } - } - return common_detokenize(ctx, text_tokens, special); - } - - size_t get_common_prefix(const server_tokens & b) const { - size_t max_idx = std::min(tokens.size(), b.tokens.size()); - for (size_t i = 0; i < max_idx; ++i) { - auto & ai = tokens[i]; - auto & bi = b.tokens[i]; - - if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { - GGML_ASSERT(has_mtmd); - const auto & a_chunk = find_chunk(i); - const auto & b_chunk = b.find_chunk(i); - GGML_ASSERT(a_chunk && b_chunk); - std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get()); - std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get()); - size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get()); - size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get()); - if (ai_id == bi_id && a_pos == b_pos) { - GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen - i += a_pos - 1; // will be +1 by the for loop - continue; - } else { - return i; - } - } else if (ai == bi) { - continue; - } else { - return i; - } - } - return max_idx; // all tokens are equal - } - - // make sure all text tokens are within the vocab range - bool validate(const struct llama_context * ctx) const { - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - const int32_t n_vocab = llama_vocab_n_tokens(vocab); - - for (size_t i = 0; i < tokens.size(); ++i) { - auto & t = tokens[i]; - if (t == LLAMA_TOKEN_NULL) { - try { - const auto & chunk = find_chunk(i); - size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); - i += n_pos - 1; // will be +1 by the for loop - } catch (const std::exception & e) { - return false; - } - } else if (t < 0 || t >= n_vocab) { - return false; - } - } - return true; - } - - // encode and decode the image chunk - int32_t process_chunk( - llama_context * ctx, - mtmd_context * mctx, - llama_pos n_past, - int32_t seq_id, - llama_pos & n_pos_out) { - auto & chunk = find_chunk(n_past); - const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE - ? "image" : "audio"; - SRV_INF("processing %s...\n", name); - int32_t n_batch = llama_n_batch(ctx); - int64_t t0 = ggml_time_ms(); - llama_pos new_n_past = n_past; - int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, - chunk.get(), - n_past, - seq_id, - n_batch, - true, // logits last - &new_n_past); - SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); - if (result != 0) { - LOG_ERR("mtmd_helper_eval failed with status %d", result); - n_pos_out = n_past; - return result; - } - n_pos_out = new_n_past; - return 0; - } -}; - // Computes FNV-1a hash of the data static std::string fnv_hash(const uint8_t * data, size_t len) { const uint64_t fnv_prime = 0x100000001b3ULL;