Skip to content

Commit

Permalink
feat: working copilot agents
Browse files Browse the repository at this point in the history
  • Loading branch information
Robitx committed Jan 20, 2024
1 parent 12900cf commit e9afe8b
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 19 deletions.
30 changes: 30 additions & 0 deletions lua/gp/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,24 @@ local config = {
.. "- Don't elide any code from your output if the answer requires coding.\n"
.. "- Take a deep breath; You've got this!\n",
},
{
-- chat agent served through the "copilot" provider (chat = true, command = false)
provider = "copilot",
name = "ChatCopilot",
chat = true,
command = false,
-- string with model name or table with model name and parameters
model = { model = "gpt-4", temperature = 1.1, top_p = 1 },
-- system prompt (use this to specify the persona/role of the AI)
system_prompt = "You are a general AI assistant.\n\n"
.. "The user provided the additional info about how they would like you to respond:\n\n"
.. "- If you're unsure don't guess and say you don't know instead.\n"
.. "- Ask question if you need clarification to provide better answer.\n"
.. "- Think deeply and carefully from first principles step by step.\n"
.. "- Zoom out first to see the big picture and then zoom in to details.\n"
.. "- Use Socratic method to improve your thinking and coding skills.\n"
.. "- Don't elide any code from your output if the answer requires coding.\n"
.. "- Take a deep breath; You've got this!\n",
},
{
provider = "openai",
name = "CodeGPT4",
Expand All @@ -125,6 +143,18 @@ local config = {
.. "Please AVOID COMMENTARY OUTSIDE OF THE SNIPPET RESPONSE.\n"
.. "START AND END YOUR ANSWER WITH:\n\n```",
},
{
-- command/editing agent served through the "copilot" provider (chat = false, command = true)
provider = "copilot",
name = "CodeCopilot",
chat = false,
command = true,
-- string with model name or table with model name and parameters
model = { model = "gpt-4", temperature = 0.8, top_p = 1, n = 1 },
-- system prompt (use this to specify the persona/role of the AI)
system_prompt = "You are an AI working as a code editor.\n\n"
.. "Please AVOID COMMENTARY OUTSIDE OF THE SNIPPET RESPONSE.\n"
.. "START AND END YOUR ANSWER WITH:\n\n```",
},
},

-- directory for storing chat files
Expand Down
93 changes: 74 additions & 19 deletions lua/gp/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,10 @@ function M.refresh_copilot_bearer()
end
local secret = M.providers.copilot.secret

if type(secret) == "table" then
return
end

local bearer = M._state.copilot_bearer or {}
if bearer.token and bearer.expires_at and bearer.expires_at > os.time() then
return
Expand Down Expand Up @@ -692,7 +696,7 @@ function M.refresh_copilot_bearer()

M._H.process(nil, "curl", curl_params, function(code, signal, stdout, stderr)
if code ~= 0 then
M.error(string.format("Copilot bearer resolve exited: %d, %d", code, signal))
M.error(string.format("Copilot bearer resolve exited: %d, %d", code, signal, stderr))
return
end

Expand Down Expand Up @@ -891,13 +895,15 @@ end

---@param provider string # provider name
function M.resolve_secret(provider)
local post_process = function(name)
local p = M.providers[name]
local post_process = function()
local p = M.providers[provider]
if p.secret and type(p.secret) == "string" then
p.secret = p.secret:gsub("^%s*(.-)%s*$", "%1")
end

M.refresh_copilot_bearer()
if provider == "copilot" then
M.refresh_copilot_bearer()
end
end

local secret = M.providers[provider].secret
Expand All @@ -922,7 +928,7 @@ function M.resolve_secret(provider)
return
end
M.providers[provider].secret = content
post_process(provider)
post_process()
else
M.warning(
"config.providers."
Expand All @@ -941,10 +947,11 @@ function M.resolve_secret(provider)
end
end)
else
post_process(provider)
post_process()
end
end

-- TODO: likely obsolete — secret handling appears to have moved to M.resolve_secret; verify no callers remain, then remove
M.valid_api_key = function()
local api_key = M.config.openai_api_key

Expand Down Expand Up @@ -981,7 +988,11 @@ M.refresh_state = function()
M._state.image_agent = M._image_agents[1]
end

M._state.copilot_bearer = M._state.copilot_bearer or state.copilot_bearer or nil
local bearer = M._state.copilot_bearer or state.copilot_bearer or nil
if bearer and bearer.expires_at and bearer.expires_at < os.time() then
bearer = nil
end
M._state.copilot_bearer = bearer

M.table_to_file(M._state, state_file)

Expand Down Expand Up @@ -1051,7 +1062,16 @@ M.prepare_commands = function()
template = M.config.template_prepend
end
end
M.Prompt(params, target, agent.cmd_prefix, agent.model, template, agent.system_prompt, whisper)
M.Prompt(
params,
target,
agent.cmd_prefix,
agent.model,
template,
agent.system_prompt,
whisper,
agent.provider
)
end

M.cmd[command] = function(params)
Expand Down Expand Up @@ -1134,10 +1154,11 @@ end

-- gpt query
---@param buf number | nil # buffer number
---@param provider string # provider name
---@param payload table # payload table for the provider's chat completion API
---@param handler function # response handler
---@param on_exit function | nil # optional on_exit handler
M.query = function(buf, payload, handler, on_exit)
M.query = function(buf, provider, payload, handler, on_exit)
-- make sure handler is a function
if type(handler) ~= "function" then
M.error(
Expand All @@ -1154,6 +1175,7 @@ M.query = function(buf, payload, handler, on_exit)
M._queries[qid] = {
timestamp = os.time(),
buf = buf,
provider = provider,
payload = payload,
handler = handler,
on_exit = on_exit,
Expand Down Expand Up @@ -1183,12 +1205,14 @@ M.query = function(buf, payload, handler, on_exit)
qt.raw_response = qt.raw_response .. line .. "\n"
end
line = line:gsub("^data: ", "")
if line:match("chat%.completion%.chunk") then
if line:match("choices") and line:match("delta") and line:match("content") then
line = vim.json.decode(line)
local content = line.choices[1].delta.content
if content ~= nil then
qt.response = qt.response .. content
handler(qid, content)
if line.choices[1] and line.choices[1].delta and line.choices[1].delta.content then
local content = line.choices[1].delta.content
if content and type(content) == "string" then
qt.response = qt.response .. content
handler(qid, content)
end
end
end
end
Expand Down Expand Up @@ -1239,7 +1263,17 @@ M.query = function(buf, payload, handler, on_exit)
end

-- try to replace model in endpoint (for azure)
local endpoint = M._H.template_replace(M.config.openai_api_endpoint, "{{model}}", payload.model)
local endpoint = M._H.template_replace(M.providers[provider].endpoint, "{{model}}", payload.model)
local bearer = M.providers[provider].secret
local headers = {}
if provider == "copilot" and M._state.copilot_bearer then
---@diagnostic disable-next-line: undefined-field
bearer = M._state.copilot_bearer.token or ""
headers = {
"-H",
"editor-version: vscode/1.85.1",
}
end

local curl_params = vim.deepcopy(M.config.curl_params or {})
local args = {
Expand All @@ -1250,7 +1284,7 @@ M.query = function(buf, payload, handler, on_exit)
"Content-Type: application/json",
-- api-key is for azure, authorization is for openai
"-H",
"Authorization: Bearer " .. M.config.openai_api_key,
"Authorization: Bearer " .. bearer,
"-H",
"api-key: " .. M.config.openai_api_key,
"-d",
Expand All @@ -1262,6 +1296,10 @@ M.query = function(buf, payload, handler, on_exit)
table.insert(curl_params, arg)
end

for _, header in ipairs(headers) do
table.insert(curl_params, header)
end

M._H.process(buf, "curl", curl_params, nil, out_reader(), nil)
end

Expand Down Expand Up @@ -2032,6 +2070,7 @@ M.chat_respond = function(params)
-- call the model and write response
M.query(
buf,
agent.provider,
M.prepare_payload(messages, headers.model, agent.model),
M.create_handler(buf, win, M._H.last_content_line(buf), true, "", not M.config.chat_free_cursor),
vim.schedule_wrap(function(qid)
Expand Down Expand Up @@ -2074,6 +2113,7 @@ M.chat_respond = function(params)
-- call the model
M.query(
nil,
agent.provider,
M.prepare_payload(messages, nil, M.config.chat_topic_gen_model),
topic_handler,
vim.schedule_wrap(function()
Expand Down Expand Up @@ -2501,7 +2541,14 @@ M.get_command_agent = function()
local name = M._state.command_agent
local model = M.agents[name].model
local system_prompt = M.agents[name].system_prompt
return { cmd_prefix = cmd_prefix, name = name, model = model, system_prompt = system_prompt }
local provider = M.agents[name].provider
return {
cmd_prefix = cmd_prefix,
name = name,
model = model,
system_prompt = system_prompt,
provider = provider,
}
end

---@return table # { cmd_prefix, name, model, system_prompt, provider }
Expand All @@ -2511,7 +2558,14 @@ M.get_chat_agent = function()
local name = M._state.chat_agent
local model = M.agents[name].model
local system_prompt = M.agents[name].system_prompt
return { cmd_prefix = cmd_prefix, name = name, model = model, system_prompt = system_prompt }
local provider = M.agents[name].provider
return {
cmd_prefix = cmd_prefix,
name = name,
model = model,
system_prompt = system_prompt,
provider = provider,
}
end

M.cmd.Context = function(params)
Expand Down Expand Up @@ -2556,7 +2610,7 @@ M.cmd.Context = function(params)
M._H.feedkeys("G", "x")
end

M.Prompt = function(params, target, prompt, model, template, system_template, whisper)
M.Prompt = function(params, target, prompt, model, template, system_template, whisper, provider)
-- enew, new, vnew, tabnew should be resolved into table
if type(target) == "function" then
target = target()
Expand Down Expand Up @@ -2824,6 +2878,7 @@ M.Prompt = function(params, target, prompt, model, template, system_template, wh
local agent = M.get_command_agent()
M.query(
buf,
provider,
M.prepare_payload(messages, model, agent.model),
handler,
vim.schedule_wrap(function(qid)
Expand Down

0 comments on commit e9afe8b

Please sign in to comment.