diff --git a/model/rag/chat.go b/model/rag/chat.go index d146d5da332..8aecc5bf1dd 100644 --- a/model/rag/chat.go +++ b/model/rag/chat.go @@ -26,6 +26,8 @@ import ( type ChatPayload struct { ChatConversationID string Query string `json:"q"` + Stream *bool `json:"stream"` + WebSearch *bool `json:"websearch"` } type ChatConversation struct { @@ -73,8 +75,10 @@ func (c *ChatConversation) Links() *jsonapi.LinksList { return nil var _ jsonapi.Object = (*ChatConversation)(nil) type QueryMessage struct { - Task string `json:"task"` - DocID string `json:"doc_id"` + Task string `json:"task"` + DocID string `json:"doc_id"` + Stream bool `json:"stream"` + WebSearch bool `json:"websearch"` } type Source struct { @@ -121,9 +125,19 @@ func Chat(inst *instance.Instance, payload ChatPayload) (*ChatConversation, erro if err != nil { return nil, err } + stream := true + if payload.Stream != nil { + stream = *payload.Stream + } + websearch := false + if payload.WebSearch != nil { + websearch = *payload.WebSearch + } query, err := job.NewMessage(&QueryMessage{ - Task: "chat-completion", - DocID: chat.DocID, + Task: "chat-completion", + DocID: chat.DocID, + Stream: stream, + WebSearch: websearch, }) if err != nil { return nil, err @@ -203,18 +217,22 @@ func Query(inst *instance.Instance, logger logger.Logger, query QueryMessage) er Content string `json:"content"` } - chat_history := make([]RAGMessage, 0, len(chat.Messages)) + chatHistory := make([]RAGMessage, 0, len(chat.Messages)) for _, msg := range chat.Messages { - chat_history = append(chat_history, RAGMessage{ + chatHistory = append(chatHistory, RAGMessage{ Role: msg.Role, Content: msg.Content, }) } + metadata := map[string]interface{}{ + "websearch": query.WebSearch, + } payload := map[string]interface{}{ "model": fmt.Sprintf("ragondin-%s", inst.Domain), - "messages": chat_history, - "stream": true, + "messages": chatHistory, + "stream": query.Stream, + "metadata": metadata, "temperature": Temperature, "top_p": TopP, "logprobs": LogProbs, @@ -234,14 +252,89 @@ func Query(inst *instance.Instance, logger logger.Logger, query QueryMessage) er } msg := chat.Messages[len(chat.Messages)-1] + var completion string + var sources []Source + + if query.Stream { + completion, sources, err = handleStreamResponse(inst, msg, res.Body) + } else { + completion, sources, err = handleNonStreamResponse(inst, msg, res.Body) + } + if err != nil { + // Send error event to client + errorDoc := map[string]interface{}{ + "_id": msg.ID, + "object": "error", + "message": err.Error(), + } + errorPayload := couchdb.JSONDoc{ + Type: consts.ChatEvents, + M: errorDoc, + } + go realtime.GetHub().Publish(inst, realtime.EventCreate, &errorPayload, nil) + return err + } + + uuidv7, _ := uuid.NewV7() + answer := ChatMessage{ + ID: uuidv7.String(), + Role: AssistantRole, + Content: completion, + Sources: sources, + CreatedAt: time.Now().UTC(), + } + chat.Messages = append(chat.Messages, answer) + return couchdb.UpdateDoc(inst, &chat) +} + +func publishDelta(inst *instance.Instance, msgID string, content string, position int) { + doc := couchdb.JSONDoc{ + Type: consts.ChatEvents, + M: map[string]interface{}{ + "_id": msgID, + "object": "delta", + "content": content, + "position": position, + }, + } + doc.SetID(msgID) + realtime.GetHub().Publish(inst, realtime.EventCreate, &doc, nil) +} + +func publishSources(inst *instance.Instance, msgID string, sources []Source) { + doc := couchdb.JSONDoc{ + Type: consts.ChatEvents, + M: map[string]interface{}{ + "_id": msgID, + "object": "sources", + "content": sources, + }, + } + doc.SetID(msgID) + realtime.GetHub().Publish(inst, realtime.EventCreate, &doc, nil) +} + +func publishDone(inst *instance.Instance, msgID string) { + doc := couchdb.JSONDoc{ + Type: consts.ChatEvents, + M: map[string]interface{}{ + "_id": msgID, + "object": "done", + }, + } + doc.SetID(msgID) + realtime.GetHub().Publish(inst, realtime.EventCreate, &doc, nil) +} + +func handleStreamResponse(inst *instance.Instance, msg ChatMessage, body io.Reader) (string, []Source, error) { position := 0 var completion string var sources []Source - var sourcePayload couchdb.JSONDoc + var sseErr error // Realtime messages are sent to the client during the response stream // When the stream is finished, the whole answer is saved in the CouchDB document - err = foreachSSE(res.Body, func(event map[string]interface{}) { + err := foreachSSE(body, func(event map[string]interface{}) { // See https://platform.openai.com/docs/api-reference/chat-streaming/streaming#chat-streaming if event["object"] == "chat.completion.chunk" { choices, ok := event["choices"].([]interface{}) @@ -249,84 +342,72 @@ func Query(inst *instance.Instance, logger logger.Logger, query QueryMessage) er return } choice := choices[0].(map[string]interface{}) // Only one choice is possible for now - var doc map[string]interface{} if reason, ok := choice["finish_reason"].(string); ok && reason != "" { - // The response stream is finished - doc = map[string]interface{}{ - "_id": msg.ID, - "object": "done", - } + go publishDone(inst, msg.ID) } else if delta, ok := choice["delta"].(map[string]interface{}); ok { // The content is progressively reveived through a delta stream content, ok := delta["content"].(string) if !ok { return } - doc = map[string]interface{}{ - "_id": msg.ID, - "object": "delta", - "content": content, - "position": position, - } + go publishDelta(inst, msg.ID, content, position) completion += content position++ if event["extra"].(string) != "" && sources == nil { // Sources are included in all delta messages, but should be sent once - sources, err = getSources(event) - if err != nil { + sources, sseErr = getSources(event) + if sseErr != nil { return } if sources != nil { - sourceDoc := map[string]interface{}{ - "_id": msg.ID, - "object": "sources", - "content": sources, - } - sourcePayload = couchdb.JSONDoc{ - Type: consts.ChatEvents, - M: sourceDoc, - } - go realtime.GetHub().Publish(inst, realtime.EventCreate, &sourcePayload, nil) + go publishSources(inst, msg.ID, sources) } } } - - payload := couchdb.JSONDoc{ - Type: consts.ChatEvents, - M: doc, - } - payload.SetID(msg.ID) - go realtime.GetHub().Publish(inst, realtime.EventCreate, &payload, nil) } }) if err != nil { - // Send error event to client - errorDoc := map[string]interface{}{ - "_id": msg.ID, - "object": "error", - "message": err.Error(), - } - errorPayload := couchdb.JSONDoc{ - Type: consts.ChatEvents, - M: errorDoc, + return "", nil, err + } + if sseErr != nil { + return "", nil, sseErr + } + return completion, sources, nil +} + +func handleNonStreamResponse(inst *instance.Instance, msg ChatMessage, body io.Reader) (string, []Source, error) { + var event map[string]interface{} + if err := json.NewDecoder(body).Decode(&event); err != nil { + return "", nil, err + } + + var completion string + if choices, ok := event["choices"].([]interface{}); ok && len(choices) > 0 { + if choice, ok := choices[0].(map[string]interface{}); ok { + if message, ok := choice["message"].(map[string]interface{}); ok { + completion, _ = message["content"].(string) + } } - go realtime.GetHub().Publish(inst, realtime.EventCreate, &errorPayload, nil) - return err + } + if completion == "" { + return "", nil, errors.New("invalid RAG response: no completion content") } - uuidv7, _ := uuid.NewV7() - answer := ChatMessage{ - ID: uuidv7.String(), - Role: AssistantRole, - Content: completion, - Sources: sources, - CreatedAt: time.Now().UTC(), + sources, err := getSources(event) + if err != nil { + return "", nil, err } - chat.Messages = append(chat.Messages, answer) - return couchdb.UpdateDoc(inst, &chat) + + publishDelta(inst, msg.ID, completion, 0) + if sources != nil { + publishSources(inst, msg.ID, sources) + } + publishDone(inst, msg.ID) + + return completion, sources, nil } func CallRAGQuery(inst *instance.Instance, payload []byte, path string, contentType string) (*http.Response, error) {