Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 139 additions & 61 deletions model/rag/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
// ChatPayload holds the parameters of a chat request.
type ChatPayload struct {
	// ChatConversationID carries no JSON tag, so it is encoded and decoded
	// under its Go field name.
	ChatConversationID string

	// Query is read from the "q" key of the JSON body.
	Query string `json:"q"`

	// Stream is optional: a nil pointer means the key was absent, in which
	// case Chat falls back to true.
	Stream *bool `json:"stream"`

	// WebSearch is optional: a nil pointer means the key was absent, in
	// which case Chat falls back to false.
	WebSearch *bool `json:"websearch"`
}

type ChatConversation struct {
Expand Down Expand Up @@ -73,8 +75,10 @@ func (c *ChatConversation) Links() *jsonapi.LinksList { return nil
var _ jsonapi.Object = (*ChatConversation)(nil)

type QueryMessage struct {
Task string `json:"task"`
DocID string `json:"doc_id"`
Task string `json:"task"`
DocID string `json:"doc_id"`
Stream bool `json:"stream"`
WebSearch bool `json:"websearch"`
}

type Source struct {
Expand Down Expand Up @@ -121,9 +125,19 @@ func Chat(inst *instance.Instance, payload ChatPayload) (*ChatConversation, erro
if err != nil {
return nil, err
}
stream := true
if payload.Stream != nil {
stream = *payload.Stream
}
websearch := false
if payload.WebSearch != nil {
websearch = *payload.WebSearch
}
query, err := job.NewMessage(&QueryMessage{
Task: "chat-completion",
DocID: chat.DocID,
Task: "chat-completion",
DocID: chat.DocID,
Stream: stream,
WebSearch: websearch,
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -203,18 +217,19 @@ func Query(inst *instance.Instance, logger logger.Logger, query QueryMessage) er
Content string `json:"content"`
}

chat_history := make([]RAGMessage, 0, len(chat.Messages))
chatHistory := make([]RAGMessage, 0, len(chat.Messages))
for _, msg := range chat.Messages {
chat_history = append(chat_history, RAGMessage{
chatHistory = append(chatHistory, RAGMessage{
Role: msg.Role,
Content: msg.Content,
})
}

payload := map[string]interface{}{
"model": fmt.Sprintf("ragondin-%s", inst.Domain),
"messages": chat_history,
"stream": true,
"messages": chatHistory,
"stream": query.Stream,
"websearch": query.WebSearch,
"temperature": Temperature,
"top_p": TopP,
"logprobs": LogProbs,
Expand All @@ -234,99 +249,162 @@ func Query(inst *instance.Instance, logger logger.Logger, query QueryMessage) er
}

msg := chat.Messages[len(chat.Messages)-1]
var completion string
var sources []Source

if query.Stream {
completion, sources, err = handleStreamResponse(inst, msg, res.Body)
} else {
completion, sources, err = handleNonStreamResponse(inst, msg, res.Body)
}
if err != nil {
// Send error event to client
errorDoc := map[string]interface{}{
"_id": msg.ID,
"object": "error",
"message": err.Error(),
}
errorPayload := couchdb.JSONDoc{
Type: consts.ChatEvents,
M: errorDoc,
}
go realtime.GetHub().Publish(inst, realtime.EventCreate, &errorPayload, nil)
return err
}

uuidv7, _ := uuid.NewV7()
answer := ChatMessage{
ID: uuidv7.String(),
Role: AssistantRole,
Content: completion,
Sources: sources,
CreatedAt: time.Now().UTC(),
}
chat.Messages = append(chat.Messages, answer)
return couchdb.UpdateDoc(inst, &chat)
}

// publishDelta publishes one "delta" chat event on the realtime hub.
//
// The event is a JSON document of type consts.ChatEvents that carries the
// given content fragment together with its position. Its ID is set to msgID
// and it is published as a create event for inst.
func publishDelta(inst *instance.Instance, msgID string, content string, position int) {
	fields := map[string]interface{}{
		"_id":      msgID,
		"object":   "delta",
		"content":  content,
		"position": position,
	}
	event := &couchdb.JSONDoc{Type: consts.ChatEvents, M: fields}
	event.SetID(msgID)

	hub := realtime.GetHub()
	hub.Publish(inst, realtime.EventCreate, event, nil)
}

func publishSources(inst *instance.Instance, msgID string, sources []Source) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

publishDelta calls doc.SetID(msgID); should we set the ID here as well?

doc := couchdb.JSONDoc{
Type: consts.ChatEvents,
M: map[string]interface{}{
"_id": msgID,
"object": "sources",
"content": sources,
},
}
doc.SetID(msgID)
realtime.GetHub().Publish(inst, realtime.EventCreate, &doc, nil)
}

// publishDone publishes a "done" chat event on the realtime hub.
//
// The event is a JSON document of type consts.ChatEvents whose only payload
// is the object marker. Its ID is set to msgID and it is published as a
// create event for inst.
func publishDone(inst *instance.Instance, msgID string) {
	fields := map[string]interface{}{
		"_id":    msgID,
		"object": "done",
	}
	event := &couchdb.JSONDoc{Type: consts.ChatEvents, M: fields}
	event.SetID(msgID)

	hub := realtime.GetHub()
	hub.Publish(inst, realtime.EventCreate, event, nil)
}

func handleStreamResponse(inst *instance.Instance, msg ChatMessage, body io.Reader) (string, []Source, error) {
position := 0
var completion string
var sources []Source
var sourcePayload couchdb.JSONDoc
var sseErr error

// Realtime messages are sent to the client during the response stream
// When the stream is finished, the whole answer is saved in the CouchDB document
err = foreachSSE(res.Body, func(event map[string]interface{}) {
err := foreachSSE(body, func(event map[string]interface{}) {
// See https://platform.openai.com/docs/api-reference/chat-streaming/streaming#chat-streaming
if event["object"] == "chat.completion.chunk" {
choices, ok := event["choices"].([]interface{})
if !ok || len(choices) < 1 {
return
}
choice := choices[0].(map[string]interface{}) // Only one choice is possible for now
var doc map[string]interface{}

if reason, ok := choice["finish_reason"].(string); ok && reason != "" {
// The response stream is finished
doc = map[string]interface{}{
"_id": msg.ID,
"object": "done",
}
go publishDone(inst, msg.ID)
} else if delta, ok := choice["delta"].(map[string]interface{}); ok {
// The content is progressively received through a delta stream
content, ok := delta["content"].(string)
if !ok {
return
}
doc = map[string]interface{}{
"_id": msg.ID,
"object": "delta",
"content": content,
"position": position,
}
go publishDelta(inst, msg.ID, content, position)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We no longer handle "object": "error" here. Is that intentional?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I actually had a local checkout without this... It is back now

completion += content
position++

if event["extra"].(string) != "" && sources == nil {
// Sources are included in all delta messages, but should be sent once
sources, err = getSources(event)
if err != nil {
sources, sseErr = getSources(event)
if sseErr != nil {
return
}
if sources != nil {
sourceDoc := map[string]interface{}{
"_id": msg.ID,
"object": "sources",
"content": sources,
}
sourcePayload = couchdb.JSONDoc{
Type: consts.ChatEvents,
M: sourceDoc,
}
go realtime.GetHub().Publish(inst, realtime.EventCreate, &sourcePayload, nil)
go publishSources(inst, msg.ID, sources)
}
}
}

payload := couchdb.JSONDoc{
Type: consts.ChatEvents,
M: doc,
}
payload.SetID(msg.ID)
go realtime.GetHub().Publish(inst, realtime.EventCreate, &payload, nil)
}
})

if err != nil {
// Send error event to client
errorDoc := map[string]interface{}{
"_id": msg.ID,
"object": "error",
"message": err.Error(),
}
errorPayload := couchdb.JSONDoc{
Type: consts.ChatEvents,
M: errorDoc,
return "", nil, err
}
if sseErr != nil {
return "", nil, sseErr
}
return completion, sources, nil
}

func handleNonStreamResponse(inst *instance.Instance, msg ChatMessage, body io.Reader) (string, []Source, error) {
var event map[string]interface{}
if err := json.NewDecoder(body).Decode(&event); err != nil {
return "", nil, err
}

var completion string
if choices, ok := event["choices"].([]interface{}); ok && len(choices) > 0 {
if choice, ok := choices[0].(map[string]interface{}); ok {
if message, ok := choice["message"].(map[string]interface{}); ok {
completion, _ = message["content"].(string)
}
}
go realtime.GetHub().Publish(inst, realtime.EventCreate, &errorPayload, nil)
return err
}
if completion == "" {
return "", nil, errors.New("invalid RAG response: no completion content")
}

uuidv7, _ := uuid.NewV7()
answer := ChatMessage{
ID: uuidv7.String(),
Role: AssistantRole,
Content: completion,
Sources: sources,
CreatedAt: time.Now().UTC(),
sources, err := getSources(event)
if err != nil {
return "", nil, err
}
chat.Messages = append(chat.Messages, answer)
return couchdb.UpdateDoc(inst, &chat)

publishDelta(inst, msg.ID, completion, 0)
if sources != nil {
publishSources(inst, msg.ID, sources)
}
publishDone(inst, msg.ID)

return completion, sources, nil
}

func CallRAGQuery(inst *instance.Instance, payload []byte, path string, contentType string) (*http.Response, error) {
Expand Down
Loading