Commit

feat: ollama with openAI endpoint + gemini
Robitx committed Mar 24, 2024
1 parent 561b621 commit bce3808
Showing 2 changed files with 68 additions and 28 deletions.
lua/gp/config.lua (11 changes: 3 additions & 8 deletions)
@@ -43,7 +43,7 @@ local config = {
},
},
ollama = {
endpoint = "http://localhost:11434/api/chat",
endpoint = "http://localhost:11434/v1/chat/completions",
},
lmsudio = {
endpoint = "http://localhost:1234/v1/chat/completions",
@@ -128,7 +128,7 @@ local config = {
chat = true,
command = false,
-- string with model name or table with model name and parameters
model = { model = "gpt-4", temperature = 1.1, top_p = 1 },
model = { model = "dummy", temperature = 1.1, top_p = 1 },
-- system prompt (use this to specify the persona/role of the AI)
system_prompt = "You are a general AI assistant.\n\n"
.. "The user provided the additional info about how they would like you to respond:\n\n"
@@ -148,10 +148,7 @@ local config = {
-- string with model name or table with model name and parameters
model = {
model = "mistral:7b-instruct-v0.2-q4_K_M",
temperature = 1.97,
top_p = 1,
num_ctx = 8192,
min_p = 0.05,
},
-- system prompt (use this to specify the persona/role of the AI)
system_prompt = "You are a general AI assistant.",
@@ -167,7 +164,6 @@ local config = {
temperature = 0.97,
top_p = 1,
num_ctx = 8192,
min_p = 0.05,
},
-- system prompt (use this to specify the persona/role of the AI)
system_prompt = "You are a general AI assistant.",
@@ -219,7 +215,6 @@ local config = {
temperature = 1.9,
top_p = 1,
num_ctx = 8192,
min_p = 0.05,
},
-- system prompt (use this to specify the persona/role of the AI)
system_prompt = "You are an AI working as a code editor providing answers.\n\n"
@@ -232,7 +227,7 @@ local config = {
-- directory for storing chat files
chat_dir = vim.fn.stdpath("data"):gsub("/$", "") .. "/gp/chats",
-- chat user prompt prefix
chat_user_prefix = "🗨:",
chat_user_prefix = "💬:",
-- chat assistant prompt prefix (static string or a table {static, template})
-- first string has to be static, second string can contain template {{agent}}
-- just a static string is legacy and the [{{agent}}] element is added automatically
lua/gp/init.lua (85 changes: 65 additions & 20 deletions)
@@ -1156,22 +1156,63 @@ M.prepare_payload = function(messages, model, default_model, provider)
}
end

if provider == "ollama" then
local options = {}
for k, v in pairs(model) do
if k ~= "provider" and k ~= "model" then
options[k] = v
end
end
options.temperature = math.max(0, math.min(2, options.temperature or 1))
options.top_p = math.max(0, math.min(1, options.top_p or 1))
return {
model = model.model,
stream = true,
messages = messages,
options = options,
}
end
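-- The branch deleted above built Ollama's native /api/chat payload:
-- every key except "provider" and "model" was nested under "options",
-- with temperature clamped to [0, 2] and top_p to [0, 1]. The OpenAI-
-- compatible endpoint takes those parameters at the top level instead,
-- so ollama requests can now fall through to the generic return at the
-- bottom of prepare_payload. A rough sketch of the two shapes for the
-- same model table (illustrative values, not part of this commit):
--
--   native /api/chat (old):
--     { model = "m", stream = true, messages = msgs,
--       options = { temperature = 0.97, top_p = 1 } }
--
--   OpenAI-style /v1/chat/completions (new):
--     { model = "m", stream = true, messages = msgs,
--       temperature = 0.97, top_p = 1 }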
if provider == "googleai" then
for i, message in ipairs(messages) do
if message.role == "system" then
messages[i].role = "user"
end
if message.role == "assistant" then
messages[i].role = "model"
end
if message.content then
messages[i].parts = {
{
text = message.content,
},
}
messages[i].content = nil
end
end
local i = 1
while i < #messages do
if messages[i].role == messages[i + 1].role then
table.insert(messages[i].parts, {
text = messages[i + 1].parts[1].text,
})
table.remove(messages, i + 1)
else
i = i + 1
end
end
local payload = {
contents = messages,
safetySettings = {
{
category = "HARM_CATEGORY_HARASSMENT",
threshold = "BLOCK_NONE",
},
{
category = "HARM_CATEGORY_HATE_SPEECH",
threshold = "BLOCK_NONE",
},
{
category = "HARM_CATEGORY_SEXUALLY_EXPLICIT",
threshold = "BLOCK_NONE",
},
{
category = "HARM_CATEGORY_DANGEROUS_CONTENT",
threshold = "BLOCK_NONE",
},
},
generationConfig = {
temperature = math.max(0, math.min(2, model.temperature or 1)),
maxOutputTokens = model.num_ctx or 8192,
topP = math.max(0, math.min(1, model.top_p or 1)),
topK = model.top_k or 100,
},
}
return payload
end
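-- A hand-worked example of the reshaping the googleai branch above
-- performs (not part of the diff): roles are mapped for Gemini's
-- v1beta API (system -> user, since there is no separate system role;
-- assistant -> model), each content string becomes a parts table, and
-- the while-loop then folds consecutive same-role messages together so
-- the resulting contents alternate strictly between user and model.
--
--   input messages:
--     { role = "system",    content = "You are a general AI assistant." }
--     { role = "user",      content = "Hello" }
--     { role = "assistant", content = "Hi, how can I help?" }
--
--   resulting contents:
--     { role = "user", parts = {
--         { text = "You are a general AI assistant." },
--         { text = "Hello" } } }
--     { role = "model", parts = { { text = "Hi, how can I help?" } } }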

return {
model = model.model,
@@ -1274,10 +1315,9 @@ M.query = function(buf, provider, payload, handler, on_exit)
end
end

if provider == "ollama" and line:match("message") and line:match("content") then
line = vim.json.decode(line)
if line.message and line.message.content then
content = line.message.content
if qt.provider == "googleai" then
if line:match('"text":') then
content = vim.json.decode("{" .. line .. "}").text
end
end
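Gemini's streamGenerateContent endpoint (without the SSE option) responds with a pretty-printed JSON array of chunk objects, so each candidate's text arrives on a line of its own. Instead of accumulating and parsing whole objects, the new branch above grabs any line containing "text": and decodes it as a single-key JSON object. A minimal sketch of that trick, using vim.json from Neovim's standard runtime (the sample line is illustrative):

local line = '"text": "Hello there"'
if line:match('"text":') then
	local content = vim.json.decode("{" .. line .. "}").text
	-- content is now "Hello there"
end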

@@ -1358,6 +1398,11 @@ M.query = function(buf, provider, payload, handler, on_exit)
}
end

if provider == "googleai" then
headers = {}
endpoint = M._H.template_replace(endpoint, "{{secret}}", bearer)
end
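-- googleai authenticates via the URL rather than an Authorization
-- header, hence the cleared headers table above. A sketch of the
-- substitution, assuming the configured endpoint carries a {{secret}}
-- placeholder (e.g. "...:streamGenerateContent?key={{secret}}") and
-- that M._H.template_replace does a plain placeholder swap:
--
--   before: .../models/gemini-pro:streamGenerateContent?key={{secret}}
--   after:  .../models/gemini-pro:streamGenerateContent?key=<API key>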

if provider == "azure" then
headers = {
"-H",
