Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: openai support custom assistant/GPT #519

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions command/openai/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
const (
apiHost = "https://api.openai.com"
apiCompletionURL = "/v1/chat/completions"
apiThreadsURL = "/v1/threads"
apiDalleGenerateImageURL = "/v1/images/generations"
)

Expand All @@ -25,14 +26,15 @@ var client = http.Client{
Timeout: 60 * time.Second,
}

func doRequest(cfg Config, apiEndpoint string, data []byte) (*http.Response, error) {
req, err := http.NewRequest("POST", cfg.APIHost+apiEndpoint, bytes.NewBuffer(data))
func doRequest(cfg Config, method string, apiEndpoint string, data []byte) (*http.Response, error) {
req, err := http.NewRequest(method, cfg.APIHost+apiEndpoint, bytes.NewBuffer(data))
if err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+cfg.APIKey)
req.Header.Set("OpenAI-Beta", "assistants=v1")

return client.Do(req)
}
Expand Down
300 changes: 300 additions & 0 deletions command/openai/assistant.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
package openai

import (
"encoding/json"
"fmt"

Check failure on line 5 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

File is not `gofumpt`-ed (gofumpt)
"github.com/innogames/slack-bot/v2/bot/msg"
"github.com/innogames/slack-bot/v2/bot/storage"
"github.com/slack-go/slack"
"io"

Check failure on line 9 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

File is not `gofumpt`-ed (gofumpt)
"time"
)

// see https://platform.openai.com/docs/assistants/how-it-works

type assistantThreadResponse struct {
Id string `json:"id"`

Check warning on line 16 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field Id should be ID (revive)
}
type assistantStartRun struct {
AssistantId string `json:"assistant_id"`

Check warning on line 19 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field AssistantId should be AssistantID (revive)
}

type run struct {
Id string `json:"id"`

Check warning on line 23 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field Id should be ID (revive)
Status string `json:"status"`
ThreadId string `json:"thread_id"`

Check warning on line 25 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field ThreadId should be ThreadID (revive)
RequiredAction AssistantRequiredAction `json:"required_action"`
}

type AssistantRequiredAction struct {
Type string `json:"type"`
SubmitToolsOutputs struct {
ToolCalls []struct {
Id string `json:"id"`

Check warning on line 33 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field Id should be ID (revive)
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
} `json:"submit_tool_outputs"`
}

type AssistantContent struct {
Type string `json:"type"`
Text struct {
Value string `json:"value"`
} `json:"text"`
}

func (c AssistantContent) GetText() string {
return c.Text.Value
}

type AssistantStartThreads struct {
Messages []ChatMessage `json:"messages"`
}

type AssistantChatMessage struct {
Id string `json:"id"`
Role string `json:"role"`
ChatMessage []AssistantContent `json:"content"`
RunId string `json:"run_id"`

Check warning on line 62 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field RunId should be RunID (revive)
}
type assistantFullResponse struct {
Data []AssistantChatMessage `json:"data"`
}

type AssistantToolsOutput struct {
ToolsOutput []struct {
ToolCallId string `json:"tool_call_id"`

Check warning on line 70 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: struct field ToolCallId should be ToolCallID (revive)
Output string `json:"output"`
} `json:"tool_outputs"`
}

func (c *openaiCommand) callCustomGPT(messages []ChatMessage, identifier string, message msg.Ref, text string) {
c.AddReaction(":coffee:", message)
defer c.RemoveReaction(":coffee:", message)

messages = append(messages, ChatMessage{
Role: roleUser,
Content: text,
})

var threadId string

Check warning on line 84 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: var threadId should be threadID (revive)
var err error
storage.Read("gpt-thread", identifier, &threadId)

Check failure on line 86 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `storage.Read` is not checked (errcheck)
if threadId == "" {
// start a new thread!
threadId, err = createAssistantThread(c.cfg, messages)
if err != nil {
c.ReplyError(message, err)
return
}
storage.Write("gpt-thread", identifier, threadId)

Check failure on line 94 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `storage.Write` is not checked (errcheck)
} else {
// attach slack messages to an existing thread
for _, newMessage := range messages {
// todo no API to bulk add?!
addMessage(c.cfg, threadId, newMessage)

Check failure on line 99 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value is not checked (errcheck)
}
}

// start the assistant and get a "run" object
run, err := assistantRun(c.cfg, threadId)
if err != nil {
c.ReplyError(message, err)
return
}

// wait till run is done or required more information from function calls!
// see https://platform.openai.com/docs/assistants/how-it-works/run-lifecycle
ticker := time.NewTicker(time.Second * 1)
defer ticker.Stop()
for range ticker.C {
run, err = getRun(c.cfg, run)
if err != nil || run.Status == "failed" || run.Status == "cancelled" || run.Status == "expired" {
c.ReplyError(message, fmt.Errorf("run failed with status %s", run.Status))
return
}

if run.Status == "completed" {
// we have the final answer!
break
}

if run.Status == "requires_action" {
// todo extract code!
fmt.Println(run.RequiredAction)
fmt.Println(run.RequiredAction.SubmitToolsOutputs)
tool := run.RequiredAction.SubmitToolsOutputs.ToolCalls[0]

var output string
if tool.Function.Name == "dall_image" {
// special function
prompt := tool.Function.Arguments
fmt.Println(prompt, "prompt")

images, _ := generateImages(c.cfg, prompt)
output = images[0].RevisedPrompt
go c.sendImageInSlack(images[0], message)

Check failure on line 140 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `c.sendImageInSlack` is not checked (errcheck)
} else {
output = "Ticket: Fix issue in feature XYZ, status = open" // todo call function
}

sendToolsOutput(c.cfg, run, tool.Id, output)

Check failure on line 145 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value is not checked (errcheck)

// wait for new tick, as the API is handling the new information now...
continue
}
}

// todo only fetch the new messages for this run
respMessages, _ := listMessages(c.cfg, threadId)
for _, m := range respMessages {
if m.RunId != run.Id {
continue
}
fmt.Println(m.ChatMessage)
if m.Role != roleAssistant {
continue
}

// reply in thread
c.SendMessage(
message,
m.ChatMessage[0].GetText(),
slack.MsgOptionTS(message.GetTimestamp()),
)
}
}

/*
func (c *openaiCommand) assistantUploadFile(cfg Config, file slack.File) error {
var buf bytes.Buffer
log.Infof("Downloading message attachment file %s", file.Name)

fmt.Println(file)

resp, err := doRequest(cfg, "POST", apiFilesURL, []byte("jolo"))
if err != nil {
return nil
}

r, _ := io.ReadAll(resp.Body)
fmt.Println(string(r))

return nil
}
*/

func assistantRun(cfg Config, threadId string) (*run, error) {

Check warning on line 191 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: func parameter threadId should be threadID (revive)
fmt.Printf("run assistant %s\n", threadId)

assistantStartRun := assistantStartRun{
AssistantId: cfg.CustomGPT,
}

req, _ := json.Marshal(assistantStartRun)
resp, err := doRequest(cfg, "POST", apiThreadsURL+"/"+threadId+"/runs", req)
if err != nil {
return nil, err
}

run := &run{}
err = json.NewDecoder(resp.Body).Decode(run)
return run, err
}

func addMessage(cfg Config, threadId string, message ChatMessage) error {
fmt.Printf("add message to thread %s: %s\n", threadId, message)

req, _ := json.Marshal(message)
_, err := doRequest(cfg, "POST", apiThreadsURL+"/"+threadId+"/messages", req)

return err
}

func createAssistantThread(cfg Config, messages []ChatMessage) (string, error) {
fmt.Println("create thread")

req, _ := json.Marshal(AssistantStartThreads{
Messages: messages,
})
fmt.Println(string(req))
resp, err := doRequest(cfg, "POST", apiThreadsURL, req)
if err != nil {
return "", err
}
//r, _ := io.ReadAll(resp.Body)

Check failure on line 229 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

commentFormatting: put a space between `//` and comment text (gocritic)
//fmt.Println(string(r))
thread := assistantThreadResponse{}
err = json.NewDecoder(resp.Body).Decode(&thread)
if err != nil {
return "", err
}
fmt.Println(thread)

if thread.Id == "" {
return "", fmt.Errorf("failed to create thread")
}
return thread.Id, nil
}

func getRun(cfg Config, oldRun *run) (*run, error) {
fmt.Printf("get run %s %s\n", oldRun.ThreadId, oldRun.Id)
resp, err := doRequest(cfg, "GET", apiThreadsURL+"/"+oldRun.ThreadId+"/runs/"+oldRun.Id, nil)
if err != nil {
return oldRun, err
}

r, _ := io.ReadAll(resp.Body)
fmt.Println(string(r))

newRun := &run{}
err = json.Unmarshal(r, newRun)

return newRun, err
}

func listMessages(cfg Config, threadId string) ([]AssistantChatMessage, error) {
fmt.Printf("list messages %s \n", threadId)
resp, err := doRequest(cfg, "GET", apiThreadsURL+"/"+threadId+"/messages", nil)
if err != nil {
return []AssistantChatMessage{}, err
}

r, _ := io.ReadAll(resp.Body)
fmt.Println(string(r))

messages := assistantFullResponse{}
json.Unmarshal(r, &messages)

Check failure on line 271 in command/openai/assistant.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `json.Unmarshal` is not checked (errcheck)

return messages.Data, nil
}

func sendToolsOutput(cfg Config, run *run, callId string, output string) error {
fmt.Printf("send tools output %s %s %s\n", run.ThreadId, run.Id, callId)

req, _ := json.Marshal(AssistantToolsOutput{
ToolsOutput: []struct {
ToolCallId string `json:"tool_call_id"`
Output string `json:"output"`
}{
{
ToolCallId: callId,
Output: output,
},
},
})
fmt.Println(string(req))
resp, err := doRequest(cfg, "POST", apiThreadsURL+"/"+run.ThreadId+"/runs/"+run.Id+"/submit_tool_outputs", req)
if err != nil {
return err
}

r, _ := io.ReadAll(resp.Body)
fmt.Println(string(r))

return err
}
2 changes: 1 addition & 1 deletion command/openai/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func CallChatGPT(cfg Config, inputMessages []ChatMessage, stream bool) (<-chan s
Stream: stream,
Messages: inputMessages,
})
resp, err := doRequest(cfg, apiCompletionURL, jsonData)
resp, err := doRequest(cfg, "POST", apiCompletionURL, jsonData)
if err != nil {
messageUpdates <- err.Error()
return
Expand Down
18 changes: 15 additions & 3 deletions command/openai/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@
// bot function which is called, when the user started a new conversation with openai/chatgpt
func (c *openaiCommand) newConversation(match matcher.Result, message msg.Message) {
text := match.GetString(util.FullMatch)
c.startConversation(message.MessageRef, text)
c.startConversation(message, text)
}

func (c *openaiCommand) startConversation(message msg.Ref, text string) bool {
messageHistory := make([]ChatMessage, 0)

if c.cfg.InitialSystemMessage != "" {
if c.cfg.InitialSystemMessage != "" && c.cfg.CustomGPT == "" {
messageHistory = append(messageHistory, ChatMessage{
Role: roleSystem,
Content: c.cfg.InitialSystemMessage,
Expand Down Expand Up @@ -135,7 +135,12 @@
storageIdentifier = getIdentifier(message.GetChannel(), message.GetTimestamp())
}

c.callAndStore(messageHistory, storageIdentifier, message, text)
if c.cfg.CustomGPT != "" {
c.callCustomGPT(messageHistory, storageIdentifier, message, text)
} else {
// usual GPT-X model
c.callAndStore(messageHistory, storageIdentifier, message, text)
}
return true
}

Expand All @@ -149,6 +154,13 @@
// Load the chat history from storage.
identifier := getIdentifier(message.GetChannel(), message.GetThread())

var threadId string

Check warning on line 157 in command/openai/command.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: var threadId should be threadID (revive)
storage.Read("gpt-thread", identifier, &threadId)

Check failure on line 158 in command/openai/command.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `storage.Read` is not checked (errcheck)
if threadId != "" {
c.callCustomGPT([]ChatMessage{}, identifier, message, text)
return true
}

var messages []ChatMessage
err := storage.Read(storageKey, identifier, &messages)
if err != nil || len(messages) == 0 {
Expand Down
2 changes: 2 additions & 0 deletions command/openai/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ type Config struct {
DalleModel string `mapstructure:"dalle_model"`
DalleImageSize string `mapstructure:"dalle_image_size"`
DalleNumberOfImages int `mapstructure:"dalle_number_of_images"`

CustomGPT string `mapstructure:"custom_gpt"`
}

// IsEnabled checks if token is set
Expand Down
Loading
Loading