Skip to content

Commit

Permalink
feat: add support for ChatGPT-4o
Browse files Browse the repository at this point in the history
  • Loading branch information
Trojan295 committed May 27, 2024
1 parent 14a06ad commit 563c88d
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 57 deletions.
5 changes: 5 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
services:
airplay:
build: .
env_file:
- .envrc
15 changes: 8 additions & 7 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@ module github.com/Trojan295/discord-airplay
go 1.21

require (
github.com/bwmarrin/discordgo v0.27.2-0.20230922130345-1f0b57f11024
github.com/bwmarrin/discordgo v0.28.1
github.com/kelseyhightower/envconfig v1.4.0
github.com/sashabaranov/go-openai v1.16.0
go.uber.org/zap v1.26.0
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
github.com/sashabaranov/go-openai v1.24.1
go.uber.org/zap v1.27.0
golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d
)

require (
github.com/gorilla/websocket v1.5.0 // indirect
github.com/gorilla/websocket v1.5.1 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.20.0 // indirect
)
30 changes: 16 additions & 14 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,47 +1,49 @@
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bwmarrin/discordgo v0.27.1 h1:ib9AIc/dom1E/fSIulrBwnez0CToJE113ZGt4HoliGY=
github.com/bwmarrin/discordgo v0.27.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/bwmarrin/discordgo v0.27.2-0.20230922130345-1f0b57f11024 h1:fHuF+yROO5s6nrVaHu1HXjBmdqA1EtrrqrV/0PnfMr0=
github.com/bwmarrin/discordgo v0.27.2-0.20230922130345-1f0b57f11024/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/bwmarrin/discordgo v0.28.1 h1:gXsuo2GBO7NbR6uqmrrBDplPUx2T3nzu775q/Rd1aG4=
github.com/bwmarrin/discordgo v0.28.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sashabaranov/go-openai v1.15.3 h1:rzoNK9n+Cak+PM6OQ9puxDmFllxfnVea9StlmhglXqA=
github.com/sashabaranov/go-openai v1.15.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.16.0 h1:34W6WV84ey6OpW0p2UewZkdMu82AxGC+BzpU6iiauRw=
github.com/sashabaranov/go-openai v1.16.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.24.1 h1:DWK95XViNb+agQtuzsn+FyHhn3HQJ7Va8z04DQDJ1MI=
github.com/sashabaranov/go-openai v1.24.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.25.0 h1:4Hvk6GtkucQ790dqmj7l1eEnRdKm3k3ZUrUMS2d5+5c=
go.uber.org/zap v1.25.0/go.mod h1:JIAUzQIH94IC4fOJQm7gMmBJP5k7wQfdcnYdPoEXJYk=
go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d h1:N0hmiNbwsSNwHBAvR3QB5w25pUwH4tK0Y/RltD1j1h4=
golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
Expand Down
4 changes: 2 additions & 2 deletions pkg/discord/interactions.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ func (handler *InteractionHandler) PlaySong(s *discordgo.Session, ic *discordgo.
discordgo.SelectMenu{
CustomID: "add_song_playlist",
Options: []discordgo.SelectMenuOption{
{Label: "Add song", Value: "song", Emoji: discordgo.ComponentEmoji{Name: "🎵"}},
{Label: "Add whole playlist", Value: "playlist", Emoji: discordgo.ComponentEmoji{Name: "🎶"}},
{Label: "Add song", Value: "song", Emoji: &discordgo.ComponentEmoji{Name: "🎵"}},
{Label: "Add whole playlist", Value: "playlist", Emoji: &discordgo.ComponentEmoji{Name: "🎶"}},
},
},
},
Expand Down
4 changes: 3 additions & 1 deletion pkg/discord/voicechat.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ func (session *DiscordVoiceChatSession) SendPlayMessage(channelID string, messag
}

func (session *DiscordVoiceChatSession) EditPlayMessage(channelID, messageID string, message *bot.PlayMessage) error {
embeds := []*discordgo.MessageEmbed{GeneratePlayingSongEmbed(message)}

_, err := session.discordSession.ChannelMessageEditComplex(&discordgo.MessageEdit{
ID: messageID,
Channel: channelID,
Embeds: []*discordgo.MessageEmbed{GeneratePlayingSongEmbed(message)},
Embeds: &embeds,
})
return err
}
Expand Down
110 changes: 77 additions & 33 deletions pkg/sources/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,39 @@ package sources

import (
"context"
"encoding/json"
"fmt"
"regexp"
"strings"
"time"

"github.com/sashabaranov/go-openai"
"golang.org/x/exp/slog"
)

const AssistantID = "asst_XaJbt5qXnInquWa1K6wNUYUT"

type ChatGPTPlaylistGenerator struct {
Logger *slog.Logger
Logger *slog.Logger

openAIClient *openai.Client
model string
assistantID string
}

func NewChatGPTPlaylistGenerator(token string) *ChatGPTPlaylistGenerator {
client := openai.NewClient(token)
config := openai.DefaultConfig(token)
config.AssistantVersion = "v2"
client := openai.NewClientWithConfig(config)
return &ChatGPTPlaylistGenerator{
Logger: slog.Default(),
openAIClient: client,
model: openai.GPT4o,
assistantID: AssistantID,
}
}

type PlaylistParams struct {
Description string
Length int
Description string `json:"description"`
Length int `json:"length"`
}

type PlaylistResponse struct {
Expand All @@ -42,44 +51,79 @@ func (g *ChatGPTPlaylistGenerator) GeneratePlaylist(ctx context.Context, params
params.Length = 10
}

resp, err := g.openAIClient.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "I want you to act as a DJ. I will provide you with a description of a playlist and number of songs, and you will create it for me. You should output the list of songs, each in a new line with artist and title. Add also some nice introduction before the song list, max 250 characters. Do not include any additional information or description, simply output: <introduction>\n<song number>. <artist> - <title>",
},
messageContent, err := json.Marshal(params)
if err != nil {
return nil, fmt.Errorf("while marshaling playlist params: %w", err)
}

run, err := g.openAIClient.CreateThreadAndRun(ctx, openai.CreateThreadAndRunRequest{
RunRequest: openai.RunRequest{
Model: g.model,
AssistantID: g.assistantID,
},
Thread: openai.ThreadRequest{
Messages: []openai.ThreadMessage{
{
Role: openai.ChatMessageRoleUser,
Content: fmt.Sprintf("Create for me a playlist of %d songs of %s", params.Length, params.Description),
Role: openai.ThreadMessageRoleUser,
Content: string(messageContent),
},
},
},
)
})
if err != nil {
return nil, fmt.Errorf("while creating chat completion: %w", err)
return nil, fmt.Errorf("while creating thread and running: %w", err)
}

g.Logger.Debug("chat completion completed", "respose", resp.Choices[0].Message.Content, "prompt tokens", resp.Usage.PromptTokens, "output tokens", resp.Usage.CompletionTokens)
runWaitCtx, cancel := context.WithTimeout(ctx, time.Duration(15*time.Second))
defer cancel()

lines := strings.Split(resp.Choices[0].Message.Content, "\n")
runCh := make(chan *openai.Run)
go func(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
default:
run, err := g.openAIClient.RetrieveRun(ctx, run.ThreadID, run.ID)
if err != nil {
time.Sleep(1 * time.Second)
continue
}

regex := regexp.MustCompile(`^\d+\.(.+)$`)
if run.Status == openai.RunStatusCompleted {
runCh <- &run
return
}

playlist := make([]string, 0)
for _, line := range lines {
matches := regex.FindStringSubmatch(line)
if matches == nil {
continue
time.Sleep(1 * time.Second)
}
}
}(runWaitCtx)

playlist = append(playlist, strings.TrimSpace(matches[1]))
}
select {
case <-runWaitCtx.Done():
return nil, fmt.Errorf("timeout while waiting for run to complete")

return &PlaylistResponse{
Intro: lines[0],
Playlist: playlist,
}, nil
case run := <-runCh:
messageList, err := g.openAIClient.ListMessage(ctx, run.ThreadID, nil, nil, nil, nil)
if err != nil {
return nil, fmt.Errorf("while listing messages: %w", err)
}

if len(messageList.Messages) != 2 {
return nil, fmt.Errorf("unexpected number of messages: %d", len(messageList.Messages))
}

if len(messageList.Messages[0].Content) != 1 {
return nil, fmt.Errorf("unexpected number of message content: %d", len(messageList.Messages[0].Content))
}

response := &PlaylistResponse{}
responseText := messageList.Messages[0].Content[0].Text.Value
if err := json.Unmarshal([]byte(responseText), response); err != nil {
return nil, fmt.Errorf("while unmarshaling response: %w", err)
}

return response, nil
}
}

0 comments on commit 563c88d

Please sign in to comment.