Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: DALL-E integration #498

Merged
merged 3 commits into from
Nov 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 70 additions & 3 deletions command/openai/api.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package openai

import "github.com/pkg/errors"
import (
"bytes"
"net/http"

"github.com/pkg/errors"
)

const (
apiHost = "https://api.openai.com"
apiCompletionURL = "/v1/chat/completions"
apiHost = "https://api.openai.com"
apiCompletionURL = "/v1/chat/completions"
apiDalleGenerateImageURL = "/v1/images/generations"
)

const (
@@ -13,6 +19,18 @@ const (
roleAssistant = "assistant"
)

// doRequest sends a POST request with the given JSON payload to the
// configured OpenAI API host + endpoint, setting the Content-Type and
// Bearer-token Authorization headers from the config.
// The caller is responsible for closing the response body.
func doRequest(cfg Config, apiEndpoint string, data []byte) (*http.Response, error) {
	req, err := http.NewRequest(http.MethodPost, cfg.APIHost+apiEndpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+cfg.APIKey)

	return client.Do(req)
}

// https://platform.openai.com/docs/api-reference/chat
type ChatRequest struct {
Model string `json:"model"`
@@ -72,3 +90,52 @@ type ChatChoice struct {
FinishReason string `json:"finish_reason"`
Delta ChatMessage `json:"delta"`
}

// DalleRequest is the JSON payload for the OpenAI image generation
// endpoint (POST /v1/images/generations). Example:
/*
{
  "model": "dall-e-3",
  "prompt": "a white siamese cat",
  "n": 1,
  "size": "1024x1024"
}
*/
type DalleRequest struct {
	Model  string `json:"model"`  // e.g. "dall-e-3"
	Prompt string `json:"prompt"` // free-text description of the desired image
	N      int    `json:"n"`      // number of images to generate
	Size   string `json:"size"`   // e.g. "1024x1024"
}

// DalleResponse is the JSON response of the image generation endpoint.
// On success Data is populated; on failure the embedded Error carries
// the API's error code and message. Example success response:
/*
{
  "created": 1700233554,
  "data": [
    {
      "url": "https://XXXX"
    }
  ]
}

or, on failure:

{
  "error": {
    "code": "invalid_size",
    "message": "The size is not supported by this model.",
    "param": null,
    "type": "invalid_request_error"
  }
}
*/
type DalleResponse struct {
	Data  []DalleResponseImage `json:"data"`
	Error struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	} `json:"error"`
}

// DalleResponseImage is one generated image in a DalleResponse.
type DalleResponseImage struct {
	URL           string `json:"url"`            // download URL of the generated image
	RevisedPrompt string `json:"revised_prompt"` // prompt as rewritten by the API before generation
}
28 changes: 9 additions & 19 deletions command/openai/client.go → command/openai/chatgpt.go
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@ package openai

import (
"bufio"
"bytes"
"encoding/json"
"io"
"net/http"
@@ -15,30 +14,21 @@ import (
var client http.Client

func CallChatGPT(cfg Config, inputMessages []ChatMessage, stream bool) (<-chan string, error) {
jsonData, _ := json.Marshal(ChatRequest{
Model: cfg.Model,
Temperature: cfg.Temperature,
Seed: cfg.Seed,
MaxTokens: cfg.MaxTokens,
Stream: stream,
Messages: inputMessages,
})

req, err := http.NewRequest("POST", cfg.APIHost+apiCompletionURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+cfg.APIKey)

messageUpdates := make(chan string, 2)

// return a chan of all message updates here and listen here in the background in the event stream
go func() {
defer close(messageUpdates)

resp, err := client.Do(req)
jsonData, _ := json.Marshal(ChatRequest{
Model: cfg.Model,
Temperature: cfg.Temperature,
Seed: cfg.Seed,
MaxTokens: cfg.MaxTokens,
Stream: stream,
Messages: inputMessages,
})
resp, err := doRequest(cfg, apiCompletionURL, jsonData)
if err != nil {
messageUpdates <- err.Error()
return
2 changes: 2 additions & 0 deletions command/openai/command.go
Original file line number Diff line number Diff line change
@@ -53,6 +53,8 @@ func (c *chatGPTCommand) GetMatcher() matcher.Matcher {
matchers := []matcher.Matcher{
matcher.NewPrefixMatcher("openai", c.newConversation),
matcher.NewPrefixMatcher("chatgpt", c.newConversation),
matcher.NewPrefixMatcher("dalle", c.dalleGenerateImage),
matcher.NewPrefixMatcher("generate image", c.dalleGenerateImage),
matcher.WildcardMatcher(c.reply),
}

10 changes: 10 additions & 0 deletions command/openai/config.go
Original file line number Diff line number Diff line change
@@ -27,6 +27,11 @@ type Config struct {

// log all input+output text to the logger. This could include personal information, therefore disabled by default!
LogTexts bool `mapstructure:"log_texts"`

// Dall-E image generation
DalleModel string `mapstructure:"dalle_model"`
DalleImageSize string `mapstructure:"dalle_image_size"`
DalleNumberOfImages int `mapstructure:"dalle_number_of_images"`
}

// IsEnabled checks if token is set
@@ -40,6 +45,11 @@ var defaultConfig = Config{
UpdateInterval: time.Second,
HistorySize: 15,
InitialSystemMessage: "You are a helpful Slack bot. By default, keep your answer short and truthful",

// default dall-e config
DalleModel: "dall-e-3",
DalleImageSize: "1024x1024",
DalleNumberOfImages: 1,
}

func loadConfig(config *config.Config) Config {
69 changes: 69 additions & 0 deletions command/openai/dalle.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package openai

import (
"encoding/json"
"fmt"
"time"

"github.com/innogames/slack-bot/v2/bot/matcher"
"github.com/innogames/slack-bot/v2/bot/msg"
"github.com/innogames/slack-bot/v2/bot/util"
log "github.com/sirupsen/logrus"
)

// bot function to generate images with Dall-E.
// The full matched text is used as the prompt; the actual API call runs in a
// background goroutine so the bot stays responsive while the image renders.
func (c *chatGPTCommand) dalleGenerateImage(match matcher.Result, message msg.Message) {
	prompt := match.GetString(util.FullMatch)

	go func() {
		// show a coffee reaction while the (slow) image generation is running
		c.AddReaction(":coffee:", message)
		defer c.RemoveReaction(":coffee:", message)

		images, err := generateImage(c.cfg, prompt)
		if err != nil {
			c.ReplyError(message, err)
			return
		}

		// one line per generated image: revised prompt + link to open it
		text := ""
		for _, image := range images {
			text += fmt.Sprintf(
				" - %s: <%s|open image>\n",
				image.RevisedPrompt,
				image.URL,
			)
		}
		c.SendMessage(message, text)
	}()
}

// generateImage calls the OpenAI Dall-E image generation API with the given
// prompt, using model/size/count from the config, and returns the generated
// image(s). If the API returned an error object, it is surfaced as an error.
func generateImage(cfg Config, prompt string) ([]DalleResponseImage, error) {
	// marshaling a struct of plain strings/ints cannot fail, so the error is ignored
	jsonData, _ := json.Marshal(DalleRequest{
		Model:  cfg.DalleModel,
		Size:   cfg.DalleImageSize,
		N:      cfg.DalleNumberOfImages,
		Prompt: prompt,
	})

	start := time.Now()
	resp, err := doRequest(cfg, apiDalleGenerateImageURL, jsonData)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var response DalleResponse
	err = json.NewDecoder(resp.Body).Decode(&response)
	if err != nil {
		return nil, err
	}

	if response.Error.Message != "" {
		// pass the API message as a value, not as a format string:
		// fmt.Errorf(msg) would misinterpret any "%" in the message as a
		// format directive (go vet printf violation)
		return nil, fmt.Errorf("%s", response.Error.Message)
	}

	log.WithField("model", cfg.DalleModel).
		Infof("Dall-E image generation took %s", time.Since(start))

	return response.Data, nil
}
106 changes: 106 additions & 0 deletions command/openai/dalle_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package openai

import (
"net/http"
"testing"
"time"

"github.com/innogames/slack-bot/v2/bot"
"github.com/innogames/slack-bot/v2/bot/config"
"github.com/innogames/slack-bot/v2/bot/msg"
"github.com/innogames/slack-bot/v2/bot/storage"
"github.com/innogames/slack-bot/v2/mocks"
"github.com/stretchr/testify/assert"
)

// TestDalle covers the "dalle" command: an API error response must be
// surfaced to the user, and a successful response must be rendered as a
// Slack message with the revised prompt and image link.
func TestDalle(t *testing.T) {
	// init memory based storage
	storage.InitStorage("")

	slackClient := &mocks.SlackClient{}
	base := bot.BaseCommand{SlackClient: slackClient}

	t.Run("test http error", func(t *testing.T) {
		ts := startTestServer(
			t,
			apiDalleGenerateImageURL,
			[]testRequest{
				{
					`{"model":"dall-e-3","prompt":"a nice cat","n":1,"size":"1024x1024"}`,
					`{
						"error": {
							"code": "invalid_api_key",
							"message": "Incorrect API key provided: sk-1234**************************************567.",
							"type": "invalid_request_error"
						}
					}`,
					http.StatusUnauthorized,
				},
			},
		)
		openaiCfg := defaultConfig
		openaiCfg.APIHost = ts.URL
		openaiCfg.APIKey = "0815pass"

		cfg := &config.Config{}
		cfg.Set("openai", openaiCfg)

		defer ts.Close()

		commands := GetCommands(base, cfg)

		message := msg.Message{}
		message.Text = "dalle a nice cat"

		mocks.AssertReaction(slackClient, ":coffee:", message)
		mocks.AssertRemoveReaction(slackClient, ":coffee:", message)
		mocks.AssertError(slackClient, message, "Incorrect API key provided: sk-1234**************************************567.")

		actual := commands.Run(message)
		// the API call happens asynchronously: give the goroutine time to finish
		time.Sleep(time.Millisecond * 100)
		assert.True(t, actual)
	})

	t.Run("test generate image", func(t *testing.T) {
		ts := startTestServer(
			t,
			apiDalleGenerateImageURL,
			[]testRequest{
				{
					`{"model":"dall-e-3","prompt":"a nice cat","n":1,"size":"1024x1024"}`,
					` {
						"created": 1700233554,
						"data": [
							{
								"url": "https://example.com/image123",
								"revised_prompt": "revised prompt 1234"
							}
						]
					}`,
					// successful generation: was http.StatusUnauthorized, a
					// copy-paste leftover from the error case above
					http.StatusOK,
				},
			},
		)
		openaiCfg := defaultConfig
		openaiCfg.APIHost = ts.URL
		openaiCfg.APIKey = "0815pass"

		cfg := &config.Config{}
		cfg.Set("openai", openaiCfg)

		defer ts.Close()

		commands := GetCommands(base, cfg)

		message := msg.Message{}
		message.Text = "dalle a nice cat"

		mocks.AssertReaction(slackClient, ":coffee:", message)
		mocks.AssertRemoveReaction(slackClient, ":coffee:", message)
		mocks.AssertSlackMessage(slackClient, message, " - revised prompt 1234: <https://example.com/image123|open image>\n")

		actual := commands.Run(message)
		// the API call happens asynchronously: give the goroutine time to finish
		time.Sleep(time.Millisecond * 100)
		assert.True(t, actual)
	})
}
13 changes: 11 additions & 2 deletions command/openai/openai_test.go
Original file line number Diff line number Diff line change
@@ -26,13 +26,13 @@ type testRequest struct {
responseCode int
}

func startTestServer(t *testing.T, requests []testRequest) *httptest.Server {
func startTestServer(t *testing.T, url string, requests []testRequest) *httptest.Server {
t.Helper()

idx := 0

mux := http.NewServeMux()
mux.HandleFunc(apiCompletionURL, func(res http.ResponseWriter, req *http.Request) {
mux.HandleFunc(url, func(res http.ResponseWriter, req *http.Request) {
expected := requests[idx]
idx++

@@ -63,6 +63,7 @@ func TestOpenai(t *testing.T) {
t.Run("Start a new thread and reply", func(t *testing.T) {
ts := startTestServer(
t,
apiCompletionURL,
[]testRequest{
{
`{"model":"gpt-3.5-turbo","messages":[{"role":"system","content":"You are a helpful Slack bot. By default, keep your answer short and truthful"},{"role":"user","content":"whats 1+1?"}],"stream":true}`,
@@ -104,6 +105,7 @@ data: [DONE]`,
openaiCfg := defaultConfig
openaiCfg.APIHost = ts.URL
openaiCfg.APIKey = "0815pass"
openaiCfg.LogTexts = true
cfg := &config.Config{}
cfg.Set("openai", openaiCfg)

@@ -159,6 +161,7 @@ data: [DONE]`,
// mock openai API
ts := startTestServer(
t,
apiCompletionURL,
[]testRequest{
{
`{"model":"","messages":[{"role":"user","content":"whats 1+1?"}],"stream":true}`,
@@ -201,6 +204,8 @@ data: [DONE]`,
// mock openai API
ts := startTestServer(
t,
apiCompletionURL,

[]testRequest{
{
`{"model":"","messages":[{"role":"user","content":"whats 1+1?"}],"stream":true}`,
@@ -243,6 +248,7 @@ data: [DONE]`,
// mock openai API
ts := startTestServer(
t,
apiCompletionURL,
[]testRequest{
{
`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"whats 1+1?"}]}`,
@@ -285,6 +291,7 @@ data: [DONE]`,
t.Run("Write within a new thread", func(t *testing.T) {
ts := startTestServer(
t,
apiCompletionURL,
[]testRequest{
{
`{"model":"gpt-3.5-turbo","messages":[{"role":"system","content":"You are a helpful Slack bot. By default, keep your answer short and truthful"},{"role":"system","content":"This is a Slack bot receiving a slack thread s context, using slack user ids as identifiers. Please use user mentions in the format \u003c@U123456\u003e"},{"role":"user","content":"User \u003c@U1234\u003e wrote: thread message 1"},{"role":"user","content":"whats 1+1?"}],"stream":true}`,
@@ -304,6 +311,7 @@ data: [DONE]`,
openaiCfg := defaultConfig
openaiCfg.APIHost = ts.URL
openaiCfg.APIKey = "0815pass"
openaiCfg.LogTexts = true
cfg := &config.Config{}
cfg.Set("openai", openaiCfg)

@@ -344,6 +352,7 @@ data: [DONE]`,
t.Run("Other thread given", func(t *testing.T) {
ts := startTestServer(
t,
apiCompletionURL,
[]testRequest{
{
`{"model":"dummy-test","messages":[{"role":"user","content":"User \u003c@U1234\u003e wrote: i had a great weekend"},{"role":"user","content":"summarize this thread "}],"stream":true}`,
Binary file added docs/dalle.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 8 additions & 1 deletion readme.md
Original file line number Diff line number Diff line change
@@ -286,7 +286,7 @@ User can define his default environment once by using `set variable serverEnviro
Then the `deploy feature-123` will deploy the branch to the defined `aws-02` environment.
Each user can define his own variables.

## Openai/ChatGPT integration
## Openai/ChatGPT/Dall-e integration
It's also possible to have a [ChatGPT](https://chat.openai.com)-like conversation with the official OpenAI integration (GPT3.5)!

![openai](./docs/openai.png)
@@ -311,6 +311,13 @@ It also possible to use the function in the templates (like in custom commands o

`{{ openai "Say some short welcome words to @Jon_Doe"}}` would print something like `Hello Jon, welcome! How can I assist you today?`

### DALL-E integration

The bot is also able to generate images with the help of [DALL-E](https://openai.com/blog/dall-e/).
Just prefix your prompt with "dalle" and the bot will generate an image based on your text.

![dall-e](./docs/dalle.png)

## Quiz command
If you need a small break and want to play a little quiz game, you can do so by calling this command.
No more than 50 questions are allowed.