From 5c9d6eeb7e04ec1bea9b8a417b26bb32a658a91c Mon Sep 17 00:00:00 2001
From: wurui
Date: Fri, 17 May 2024 09:54:46 +0800
Subject: [PATCH] feat: add transID for ai function (#817)

# Description

Add a transID for the AI function. Each request is assigned a unique
transID, which can be retrieved from the request `ctx` with
`FromTransIDContext`.
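For reviewers, here is a minimal sketch of the new API surface (the handler,
route, and address are hypothetical, and the manual wrapping below stands in
for the real `WithContextService` middleware, whose full signature is not
shown in this diff):

```go
package main

import (
	"fmt"
	"net/http"

	bridgeai "github.com/yomorun/yomo/pkg/bridge/ai"
	"github.com/yomorun/yomo/pkg/id"
)

// exampleHandler is hypothetical; it only demonstrates reading back the
// transID that the middleware stored on the request context.
func exampleHandler(w http.ResponseWriter, r *http.Request) {
	// FromTransIDContext returns "" when the request never passed through
	// WithTransIDContext, so an empty value signals a wiring mistake.
	transID := bridgeai.FromTransIDContext(r.Context())
	fmt.Fprintf(w, "transID: %s\n", transID)
}

func main() {
	// Simulate the relevant part of WithContextService: mint a unique
	// 32-character transID per request and attach it to the context
	// before the wrapped handler runs.
	wrapped := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := bridgeai.WithTransIDContext(r.Context(), id.New(32))
		exampleHandler(w, r.WithContext(ctx))
	})
	_ = http.ListenAndServe("127.0.0.1:8080", wrapped)
}
```

Because the context key is an unexported struct type (`transIDContextKey{}`),
other packages cannot collide with the entry, and the checked type assertion
in `FromTransIDContext` degrades to `""` rather than panicking.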
---
 ai/model.go                 |  1 -
 pkg/bridge/ai/api_server.go | 56 ++++++++++++++++++++++---------------
 pkg/bridge/ai/service.go    | 35 ++++++++++++-----------
 3 files changed, 53 insertions(+), 39 deletions(-)

diff --git a/ai/model.go b/ai/model.go
index 107294009..69a2fe50d 100644
--- a/ai/model.go
+++ b/ai/model.go
@@ -15,7 +15,6 @@ type OverviewResponse struct {
 
 // InvokeRequest is the request from user to BasicAPIServer
 type InvokeRequest struct {
-	ReqID            string `json:"req_id"`             // ReqID is the request id of the request
 	Prompt           string `json:"prompt"`             // Prompt is user input text for chat completion
 	IncludeCallStack bool   `json:"include_call_stack"` // IncludeCallStack is the flag to include call stack in response
 }
diff --git a/pkg/bridge/ai/api_server.go b/pkg/bridge/ai/api_server.go
index 04b96ae53..26a2017fa 100644
--- a/pkg/bridge/ai/api_server.go
+++ b/pkg/bridge/ai/api_server.go
@@ -9,10 +9,10 @@ import (
 	"net/http"
 	"time"
 
-	gonanoid "github.com/matoous/go-nanoid/v2"
 	openai "github.com/sashabaranov/go-openai"
 	"github.com/yomorun/yomo/ai"
 	"github.com/yomorun/yomo/core/ylog"
+	"github.com/yomorun/yomo/pkg/id"
 )
 
 const (
@@ -97,7 +97,10 @@ func WithContextService(handler http.Handler, credential string, zipperAddr stri
 	}
 
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		handler.ServeHTTP(w, r.WithContext(WithServiceContext(r.Context(), service)))
+		transID := id.New(32)
+		ctx := WithTransIDContext(r.Context(), transID)
+		ctx = WithServiceContext(ctx, service)
+		handler.ServeHTTP(w, r.WithContext(ctx))
 	})
 }
 
@@ -119,18 +122,14 @@ func HandleOverview(w http.ResponseWriter, r *http.Request) {
 
 // HandleInvoke is the handler for POST /invoke
 func HandleInvoke(w http.ResponseWriter, r *http.Request) {
-	service := FromServiceContext(r.Context())
+	var (
+		ctx     = r.Context()
+		service = FromServiceContext(ctx)
+		transID = FromTransIDContext(ctx)
+	)
 	defer r.Body.Close()
 
-	reqID, err := gonanoid.New(6)
-	if err != nil {
-		ylog.Error("generate reqID", "err", err.Error())
-		w.WriteHeader(http.StatusInternalServerError)
-		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
-		return
-	}
 	var req ai.InvokeRequest
-	req.ReqID = reqID
 
 	// decode the request
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@@ -154,8 +153,8 @@
 	errCh := make(chan error, 1)
 	go func(service *Service, req ai.InvokeRequest, baseSystemMessage string) {
 		// call llm to infer the function and arguments to be invoked
-		ylog.Debug(">> ai request", "reqID", req.ReqID, "prompt", req.Prompt)
-		res, err := service.GetInvoke(ctx, req.Prompt, baseSystemMessage, req.ReqID, req.IncludeCallStack)
+		ylog.Debug(">> ai request", "transID", transID, "prompt", req.Prompt)
+		res, err := service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, req.IncludeCallStack)
 		if err != nil {
 			errCh <- err
 		} else {
@@ -183,16 +182,13 @@
 
 // HandleChatCompletions is the handler for POST /chat/completion
 func HandleChatCompletions(w http.ResponseWriter, r *http.Request) {
-	service := FromServiceContext(r.Context())
+	var (
+		ctx     = r.Context()
+		service = FromServiceContext(ctx)
+		transID = FromTransIDContext(ctx)
+	)
 	defer r.Body.Close()
 
-	reqID, err := gonanoid.New(6)
-	if err != nil {
-		ylog.Error("generate reqID", "err", err.Error())
-		RespondWithError(w, http.StatusInternalServerError, err)
-		return
-	}
-
 	var req openai.ChatCompletionRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		ylog.Error("decode request", "err", err.Error())
@@ -203,7 +199,7 @@
 	ctx, cancel := context.WithTimeout(r.Context(), 90*time.Second)
 	defer cancel()
 
-	if err := service.GetChatCompletions(ctx, req, reqID, w, false); err != nil {
+	if err := service.GetChatCompletions(ctx, req, transID, w, false); err != nil {
 		ylog.Error("invoke chat completions", "err", err.Error())
 		RespondWithError(w, http.StatusBadRequest, err)
 		return
@@ -248,3 +244,19 @@
 	}
 	return service
 }
+
+type transIDContextKey struct{}
+
+// WithTransIDContext adds the transID to the request context
+func WithTransIDContext(ctx context.Context, transID string) context.Context {
+	return context.WithValue(ctx, transIDContextKey{}, transID)
+}
+
+// FromTransIDContext returns the transID from the request context
+func FromTransIDContext(ctx context.Context) string {
+	val, ok := ctx.Value(transIDContextKey{}).(string)
+	if !ok {
+		return ""
+	}
+	return val
+}
diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go
index 475b92554..c945e28e1 100644
--- a/pkg/bridge/ai/service.go
+++ b/pkg/bridge/ai/service.go
@@ -18,6 +18,7 @@ import (
 	"github.com/yomorun/yomo/core/metadata"
 	"github.com/yomorun/yomo/core/ylog"
 	"github.com/yomorun/yomo/pkg/bridge/ai/register"
+	"github.com/yomorun/yomo/pkg/id"
 	"github.com/yomorun/yomo/serverless"
 )
 
@@ -164,7 +165,7 @@ func (s *Service) createReducer() (yomo.StreamFunction, error) {
 		c, ok := s.sfnCallCache[reqID]
 		s.muCallCache.Unlock()
 		if !ok {
-			ylog.Error("[sfn-reducer] req_id not found", "req_id", reqID)
+			ylog.Error("[sfn-reducer] req_id not found", "trans_id", invoke.TransID, "req_id", reqID)
 			return
 		}
 
@@ -204,7 +205,7 @@
 }
 
 // GetInvoke returns the invoke response
-func (s *Service) GetInvoke(ctx context.Context, userInstruction string, baseSystemMessage string, reqID string, includeCallStack bool) (*ai.InvokeResponse, error) {
+func (s *Service) GetInvoke(ctx context.Context, userInstruction string, baseSystemMessage string, transID string, includeCallStack bool) (*ai.InvokeResponse, error) {
 	// read tools attached to the metadata
 	tcs, err := register.ListToolCalls(s.Metadata)
 	if err != nil {
@@ -243,8 +244,8 @@
 		"res_toolcalls", fmt.Sprintf("%+v", res.ToolCalls),
 		"res_assistant_msgs", fmt.Sprintf("%+v", res.AssistantMessage))
 
-	ylog.Debug(">> run function calls", "reqID", reqID, "res.ToolCalls", fmt.Sprintf("%+v", res.ToolCalls))
-	llmCalls, err := s.runFunctionCalls(res.ToolCalls, reqID)
+	ylog.Debug(">> run function calls", "transID", transID, "res.ToolCalls", fmt.Sprintf("%+v", res.ToolCalls))
+	llmCalls, err := s.runFunctionCalls(res.ToolCalls, transID, id.New(16))
 	if err != nil {
 		return nil, err
 	}
@@ -322,7 +323,7 @@ func overWriteSystemPrompt(req openai.ChatCompletionRequest, sysPrompt string) o
 }
 
 // GetChatCompletions returns the llm api response
-func (s *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, reqID string, w http.ResponseWriter, includeCallStack bool) error {
+func (s *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, transID string, w http.ResponseWriter, includeCallStack bool) error {
 	// 1. find all hosting tool sfn
 	tagTools, err := register.ListToolCalls(s.Metadata)
 	if err != nil {
@@ -386,7 +387,7 @@ func (s *Service) GetChatCompletions(ctx context.Context, req openai.ChatComplet
 				toolCallsMap[index] = item
 			}
 			isFunctionCall = true
-		} else {
+		} else if streamRes.Choices[0].FinishReason != openai.FinishReasonToolCalls {
 			_, _ = io.WriteString(w, "data: ")
 			_ = json.NewEncoder(w).Encode(streamRes)
 			_, _ = io.WriteString(w, "\n")
@@ -436,7 +437,8 @@ func (s *Service) GetChatCompletions(ctx context.Context, req openai.ChatComplet
 		}
 	}
 	// 6. run llm function calls
-	llmCalls, err := s.runFunctionCalls(fnCalls, reqID)
+	reqID := id.New(16)
+	llmCalls, err := s.runFunctionCalls(fnCalls, transID, reqID)
 	if err != nil {
 		return err
 	}
@@ -471,9 +473,6 @@ func (s *Service) GetChatCompletions(ctx context.Context, req openai.ChatComplet
 		if err != nil {
 			return err
 		}
-		if len(streamRes.Choices) == 0 {
-			continue
-		}
 		_, _ = io.WriteString(w, "data: ")
 		_ = json.NewEncoder(w).Encode(streamRes)
 		_, _ = io.WriteString(w, "\n")
@@ -491,7 +490,7 @@
 }
 
 // run llm-sfn function calls
-func (s *Service) runFunctionCalls(fns map[uint32][]*openai.ToolCall, reqID string) ([]ai.ToolMessage, error) {
+func (s *Service) runFunctionCalls(fns map[uint32][]*openai.ToolCall, transID, reqID string) ([]ai.ToolMessage, error) {
 	if len(fns) == 0 {
 		return nil, nil
 	}
@@ -499,14 +498,15 @@ func (s *Service) runFunctionCalls(fns map[uint32][]*openai.ToolCall, reqID stri
 	asyncCall := &sfnAsyncCall{
 		val: make(map[string]ai.ToolMessage),
 	}
+
 	s.muCallCache.Lock()
 	s.sfnCallCache[reqID] = asyncCall
 	s.muCallCache.Unlock()
 	for tag, tcs := range fns {
-		ylog.Debug("+++invoke toolCalls", "tag", tag, "len(toolCalls)", len(tcs), "reqID", reqID)
+		ylog.Debug("+++invoke toolCalls", "tag", tag, "len(toolCalls)", len(tcs), "transID", transID, "reqID", reqID)
 		for _, fn := range tcs {
-			err := s.fireLlmSfn(tag, fn, reqID)
+			err := s.fireLlmSfn(tag, fn, transID, reqID)
 			if err != nil {
 				ylog.Error("send data to zipper", "err", err.Error())
 				continue
 			}
@@ -533,19 +533,22 @@
 }
 
 // fireLlmSfn fires the llm-sfn function call by s.source.Write()
-func (s *Service) fireLlmSfn(tag uint32, fn *openai.ToolCall, reqID string) error {
+func (s *Service) fireLlmSfn(tag uint32, fn *openai.ToolCall, transID, reqID string) error {
 	ylog.Info(
 		"+invoke func",
 		"tag", tag,
+		"transID", transID,
+		"reqID", reqID,
 		"toolCallID", fn.ID,
 		"function", fn.Function.Name,
 		"arguments", fn.Function.Arguments,
-		"reqID", reqID)
+	)
 	data := &ai.FunctionCall{
+		TransID:      transID,
 		ReqID:        reqID,
 		ToolCallID:   fn.ID,
-		Arguments:    fn.Function.Arguments,
 		FunctionName: fn.Function.Name,
+		Arguments:    fn.Function.Arguments,
 	}
 	buf, err := data.Bytes()
 	if err != nil {