From e9afe8b5054435be6b8897ccb94ae74831fd2d38 Mon Sep 17 00:00:00 2001 From: "Schmidt, Tibor" Date: Sat, 20 Jan 2024 23:53:16 +0100 Subject: [PATCH] feat: working copilot agents --- lua/gp/config.lua | 30 +++++++++++++++ lua/gp/init.lua | 93 +++++++++++++++++++++++++++++++++++++---------- 2 files changed, 104 insertions(+), 19 deletions(-) diff --git a/lua/gp/config.lua b/lua/gp/config.lua index 4d01471..c510a3a 100644 --- a/lua/gp/config.lua +++ b/lua/gp/config.lua @@ -101,6 +101,24 @@ local config = { .. "- Don't elide any code from your output if the answer requires coding.\n" .. "- Take a deep breath; You've got this!\n", }, + { + provider = "copilot", + name = "ChatCopilot", + chat = true, + command = false, + -- string with model name or table with model name and parameters + model = { model = "gpt-4", temperature = 1.1, top_p = 1 }, + -- system prompt (use this to specify the persona/role of the AI) + system_prompt = "You are a general AI assistant.\n\n" + .. "The user provided the additional info about how they would like you to respond:\n\n" + .. "- If you're unsure don't guess and say you don't know instead.\n" + .. "- Ask question if you need clarification to provide better answer.\n" + .. "- Think deeply and carefully from first principles step by step.\n" + .. "- Zoom out first to see the big picture and then zoom in to details.\n" + .. "- Use Socratic method to improve your thinking and coding skills.\n" + .. "- Don't elide any code from your output if the answer requires coding.\n" + .. "- Take a deep breath; You've got this!\n", + }, { provider = "openai", name = "CodeGPT4", @@ -125,6 +143,18 @@ local config = { .. "Please AVOID COMMENTARY OUTSIDE OF THE SNIPPET RESPONSE.\n" .. "START AND END YOUR ANSWER WITH:\n\n```", }, + { + provider = "copilot", + name = "CodeCopilot", + chat = false, + command = true, + -- string with the Copilot engine name or table with engine name and parameters if applicable + model = { model = "gpt-4", temperature = 0.8, top_p = 1, n = 1 }, + -- system prompt (use this to specify the persona/role of the AI) + system_prompt = "You are an AI working as a code editor.\n\n" + .. "Please AVOID COMMENTARY OUTSIDE OF THE SNIPPET RESPONSE.\n" + .. "START AND END YOUR ANSWER WITH:\n\n```", + }, }, -- directory for storing chat files diff --git a/lua/gp/init.lua b/lua/gp/init.lua index 57e8b7c..338c49a 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -662,6 +662,10 @@ function M.refresh_copilot_bearer() end local secret = M.providers.copilot.secret + if type(secret) == "table" then + return + end + local bearer = M._state.copilot_bearer or {} if bearer.token and bearer.expires_at and bearer.expires_at > os.time() then return @@ -692,7 +696,7 @@ function M.refresh_copilot_bearer() M._H.process(nil, "curl", curl_params, function(code, signal, stdout, stderr) if code ~= 0 then - M.error(string.format("Copilot bearer resolve exited: %d, %d", code, signal)) + M.error(string.format("Copilot bearer resolve exited: %d, %d", code, signal, stderr)) return end @@ -891,13 +895,15 @@ end ---@provider string # provider name function M.resolve_secret(provider) - local post_process = function(name) - local p = M.providers[name] + local post_process = function() + local p = M.providers[provider] if p.secret and type(p.secret) == "string" then p.secret = p.secret:gsub("^%s*(.-)%s*$", "%1") end - M.refresh_copilot_bearer() + if provider == "copilot" then + M.refresh_copilot_bearer() + end end local secret = M.providers[provider].secret @@ -922,7 +928,7 @@ function M.resolve_secret(provider) return end M.providers[provider].secret = content - post_process(provider) + post_process() else M.warning( "config.providers." @@ -941,10 +947,11 @@ function M.resolve_secret(provider) end end) else - post_process(provider) + post_process() end end +--TODO: obsolete M.valid_api_key = function() local api_key = M.config.openai_api_key @@ -981,7 +988,11 @@ M.refresh_state = function() M._state.image_agent = M._image_agents[1] end - M._state.copilot_bearer = M._state.copilot_bearer or state.copilot_bearer or nil + local bearer = M._state.copilot_bearer or state.copilot_bearer or nil + if bearer and bearer.expires_at and bearer.expires_at < os.time() then + bearer = nil + end + M._state.copilot_bearer = bearer M.table_to_file(M._state, state_file) @@ -1051,7 +1062,16 @@ M.prepare_commands = function() template = M.config.template_prepend end end - M.Prompt(params, target, agent.cmd_prefix, agent.model, template, agent.system_prompt, whisper) + M.Prompt( + params, + target, + agent.cmd_prefix, + agent.model, + template, + agent.system_prompt, + whisper, + agent.provider + ) end M.cmd[command] = function(params) @@ -1134,10 +1154,11 @@ end -- gpt query ---@param buf number | nil # buffer number +---@param provider string # provider name ---@param payload table # payload for openai api ---@param handler function # response handler ---@param on_exit function | nil # optional on_exit handler -M.query = function(buf, payload, handler, on_exit) +M.query = function(buf, provider, payload, handler, on_exit) -- make sure handler is a function if type(handler) ~= "function" then M.error( @@ -1154,6 +1175,7 @@ M.query = function(buf, payload, handler, on_exit) M._queries[qid] = { timestamp = os.time(), buf = buf, + provider = provider, payload = payload, handler = handler, on_exit = on_exit, @@ -1183,12 +1205,14 @@ M.query = function(buf, payload, handler, on_exit) qt.raw_response = qt.raw_response .. line .. "\n" end line = line:gsub("^data: ", "") - if line:match("chat%.completion%.chunk") then + if line:match("choices") and line:match("delta") and line:match("content") then line = vim.json.decode(line) - local content = line.choices[1].delta.content - if content ~= nil then - qt.response = qt.response .. content - handler(qid, content) + if line.choices[1] and line.choices[1].delta and line.choices[1].delta.content then + local content = line.choices[1].delta.content + if content and type(content) == "string" then + qt.response = qt.response .. content + handler(qid, content) + end end end end @@ -1239,7 +1263,17 @@ M.query = function(buf, payload, handler, on_exit) end -- try to replace model in endpoint (for azure) - local endpoint = M._H.template_replace(M.config.openai_api_endpoint, "{{model}}", payload.model) + local endpoint = M._H.template_replace(M.providers[provider].endpoint, "{{model}}", payload.model) + local bearer = M.providers[provider].secret + local headers = {} + if provider == "copilot" and M._state.copilot_bearer then + ---@diagnostic disable-next-line: undefined-field + bearer = M._state.copilot_bearer.token or "" + headers = { + "-H", + "editor-version: vscode/1.85.1", + } + end local curl_params = vim.deepcopy(M.config.curl_params or {}) local args = { @@ -1250,7 +1284,7 @@ M.query = function(buf, payload, handler, on_exit) "Content-Type: application/json", -- api-key is for azure, authorization is for openai "-H", - "Authorization: Bearer " .. M.config.openai_api_key, + "Authorization: Bearer " .. bearer, "-H", "api-key: " .. M.config.openai_api_key, "-d", @@ -1262,6 +1296,10 @@ M.query = function(buf, payload, handler, on_exit) table.insert(curl_params, arg) end + for _, header in ipairs(headers) do + table.insert(curl_params, header) + end + M._H.process(buf, "curl", curl_params, nil, out_reader(), nil) end @@ -2032,6 +2070,7 @@ M.chat_respond = function(params) -- call the model and write response M.query( buf, + agent.provider, M.prepare_payload(messages, headers.model, agent.model), M.create_handler(buf, win, M._H.last_content_line(buf), true, "", not M.config.chat_free_cursor), vim.schedule_wrap(function(qid) @@ -2074,6 +2113,7 @@ M.chat_respond = function(params) -- call the model M.query( nil, + agent.provider, M.prepare_payload(messages, nil, M.config.chat_topic_gen_model), topic_handler, vim.schedule_wrap(function() @@ -2501,7 +2541,14 @@ M.get_command_agent = function() local name = M._state.command_agent local model = M.agents[name].model local system_prompt = M.agents[name].system_prompt - return { cmd_prefix = cmd_prefix, name = name, model = model, system_prompt = system_prompt } + local provider = M.agents[name].provider + return { + cmd_prefix = cmd_prefix, + name = name, + model = model, + system_prompt = system_prompt, + provider = provider, + } end ---@return table # { cmd_prefix, name, model, system_prompt } @@ -2511,7 +2558,14 @@ M.get_chat_agent = function() local name = M._state.chat_agent local model = M.agents[name].model local system_prompt = M.agents[name].system_prompt - return { cmd_prefix = cmd_prefix, name = name, model = model, system_prompt = system_prompt } + local provider = M.agents[name].provider + return { + cmd_prefix = cmd_prefix, + name = name, + model = model, + system_prompt = system_prompt, + provider = provider, + } end M.cmd.Context = function(params) @@ -2556,7 +2610,7 @@ M.cmd.Context = function(params) M._H.feedkeys("G", "x") end -M.Prompt = function(params, target, prompt, model, template, system_template, whisper) +M.Prompt = function(params, target, prompt, model, template, system_template, whisper, provider) -- enew, new, vnew, tabnew should be resolved into table if type(target) == "function" then target = target() @@ -2824,6 +2878,7 @@ M.Prompt = function(params, target, prompt, model, template, system_template, wh local agent = M.get_command_agent() M.query( buf, + provider, M.prepare_payload(messages, model, agent.model), handler, vim.schedule_wrap(function(qid)