diff --git a/lua/gp/config.lua b/lua/gp/config.lua index b8cb596..068c837 100644 --- a/lua/gp/config.lua +++ b/lua/gp/config.lua @@ -43,7 +43,7 @@ local config = { }, }, ollama = { - endpoint = "http://localhost:11434/api/chat", + endpoint = "http://localhost:11434/v1/chat/completions", }, lmsudio = { endpoint = "http://localhost:1234/v1/chat/completions", @@ -128,7 +128,7 @@ local config = { chat = true, command = false, -- string with model name or table with model name and parameters - model = { model = "gpt-4", temperature = 1.1, top_p = 1 }, + model = { model = "dummy", temperature = 1.1, top_p = 1 }, -- system prompt (use this to specify the persona/role of the AI) system_prompt = "You are a general AI assistant.\n\n" .. "The user provided the additional info about how they would like you to respond:\n\n" @@ -148,10 +148,7 @@ local config = { -- string with model name or table with model name and parameters model = { model = "mistral:7b-instruct-v0.2-q4_K_M", - temperature = 1.97, - top_p = 1, num_ctx = 8192, - min_p = 0.05, }, -- system prompt (use this to specify the persona/role of the AI) system_prompt = "You are a general AI assistant.", @@ -167,7 +164,6 @@ local config = { temperature = 0.97, top_p = 1, num_ctx = 8192, - min_p = 0.05, }, -- system prompt (use this to specify the persona/role of the AI) system_prompt = "You are a general AI assistant.", @@ -219,7 +215,6 @@ local config = { temperature = 1.9, top_p = 1, num_ctx = 8192, - min_p = 0.05, }, -- system prompt (use this to specify the persona/role of the AI) system_prompt = "You are an AI working as a code editor providing answers.\n\n" @@ -232,7 +227,7 @@ local config = { -- directory for storing chat files chat_dir = vim.fn.stdpath("data"):gsub("/$", "") .. "/gp/chats", -- chat user prompt prefix - chat_user_prefix = "🗨:", + chat_user_prefix = "💬:", -- chat assistant prompt prefix (static string or a table {static, template}) -- first string has to be static, second string can contain template {{agent}} -- just a static string is legacy and the [{{agent}}] element is added automatically diff --git a/lua/gp/init.lua b/lua/gp/init.lua index c2dcc90..d2f419f 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -1156,22 +1156,63 @@ M.prepare_payload = function(messages, model, default_model, provider) } end - if provider == "ollama" then - local options = {} - for k, v in pairs(model) do - if k ~= "provider" and k ~= "model" then - options[k] = v - end - end - options.temperature = math.max(0, math.min(2, options.temperature or 1)) - options.top_p = math.max(0, math.min(1, options.top_p or 1)) - return { - model = model.model, - stream = true, - messages = messages, - options = options, - } - end + if provider == "googleai" then + for i, message in ipairs(messages) do + if message.role == "system" then + messages[i].role = "user" + end + if message.role == "assistant" then + messages[i].role = "model" + end + if message.content then + messages[i].parts = { + { + text = message.content, + }, + } + messages[i].content = nil + end + end + local i = 1 + while i < #messages do + if messages[i].role == messages[i + 1].role then + table.insert(messages[i].parts, { + text = messages[i + 1].parts[1].text, + }) + table.remove(messages, i + 1) + else + i = i + 1 + end + end + local payload = { + contents = messages, + safetySettings = { + { + category = "HARM_CATEGORY_HARASSMENT", + threshold = "BLOCK_NONE", + }, + { + category = "HARM_CATEGORY_HATE_SPEECH", + threshold = "BLOCK_NONE", + }, + { + category = "HARM_CATEGORY_SEXUALLY_EXPLICIT", + threshold = "BLOCK_NONE", + }, + { + category = "HARM_CATEGORY_DANGEROUS_CONTENT", + threshold = "BLOCK_NONE", + }, + }, + generationConfig = { + temperature = math.max(0, math.min(2, model.temperature or 1)), + maxOutputTokens = model.num_ctx or 8192, + topP = math.max(0, math.min(1, model.top_p or 1)), + topK = model.top_k or 100, + }, + } + return payload + end return { model = model.model, @@ -1274,10 +1315,9 @@ M.query = function(buf, provider, payload, handler, on_exit) end end - if provider == "ollama" and line:match("message") and line:match("content") then - line = vim.json.decode(line) - if line.message and line.message.content then - content = line.message.content + if qt.provider == "googleai" then + if line:match('"text":') then + content = vim.json.decode("{" .. line .. "}").text end end @@ -1358,6 +1398,11 @@ M.query = function(buf, provider, payload, handler, on_exit) } end + if provider == "googleai" then + headers = {} + endpoint = M._H.template_replace(endpoint, "{{secret}}", bearer) + end + if provider == "azure" then headers = { "-H",