diff --git a/ai/model.go b/ai/model.go
index 69a2fe50d..4b093275f 100644
--- a/ai/model.go
+++ b/ai/model.go
@@ -22,17 +22,17 @@ type InvokeRequest struct {
 // InvokeResponse is the response for chat completions
 type InvokeResponse struct {
     // Content is the content from llm api response
-    Content string
+    Content string `json:"content,omitempty"`
     // ToolCalls is the toolCalls from llm api response
-    ToolCalls map[uint32][]*openai.ToolCall
+    ToolCalls map[uint32][]*openai.ToolCall `json:"tool_calls,omitempty"`
     // ToolMessages is the tool messages from llm api response
-    ToolMessages []ToolMessage
+    ToolMessages []ToolMessage `json:"tool_messages,omitempty"`
     // FinishReason is the finish reason from llm api response
-    FinishReason string
+    FinishReason string `json:"finish_reason,omitempty"`
     // TokenUsage is the token usage from llm api response
-    TokenUsage TokenUsage
+    TokenUsage TokenUsage `json:"token_usage,omitempty"`
     // AssistantMessage is the assistant message from llm api response, only present when finish reason is "tool_calls"
-    AssistantMessage interface{}
+    AssistantMessage interface{} `json:"assistant_message,omitempty"`
 }
 
 // TokenUsage is the token usage in Response
@@ -63,7 +63,7 @@ type ParameterProperty struct {
 type ToolMessage struct {
     Role    string `json:"role"`
     Content string `json:"content"`
-    ToolCallId string `json:"tool_call_id"`
+    ToolCallID string `json:"tool_call_id"`
 }
 
 // ChainMessage is the message for chaining llm request with preceeding `tool_calls` response
diff --git a/ai/openai.go b/ai/openai.go
index 1debeb096..24a868493 100644
--- a/ai/openai.go
+++ b/ai/openai.go
@@ -8,6 +8,7 @@ import (
     "github.com/yomorun/yomo/core/ylog"
 )
 
+// ConvertToInvokeResponse converts openai.ChatCompletionResponse struct to InvokeResponse struct.
 func ConvertToInvokeResponse(res *openai.ChatCompletionResponse, tcs map[uint32]openai.Tool) (*InvokeResponse, error) {
     choice := res.Choices[0]
     ylog.Debug(">>finish_reason", "reason", choice.FinishReason)
diff --git a/cli/serve.go b/cli/serve.go
index 8d5dca409..3df985b79 100644
--- a/cli/serve.go
+++ b/cli/serve.go
@@ -27,6 +27,7 @@ import (
     "github.com/yomorun/yomo/pkg/log"
 
     "github.com/yomorun/yomo/pkg/bridge/ai"
+    providerpkg "github.com/yomorun/yomo/pkg/bridge/ai/provider"
     "github.com/yomorun/yomo/pkg/bridge/ai/provider/azopenai"
     "github.com/yomorun/yomo/pkg/bridge/ai/provider/cfazure"
     "github.com/yomorun/yomo/pkg/bridge/ai/provider/cfopenai"
@@ -78,7 +79,7 @@ var serveCmd = &cobra.Command{
         }
         if aiConfig != nil {
             // add AI connection middleware
-            options = append(options, yomo.WithZipperConnMiddleware(ai.ConnMiddleware))
+            options = append(options, yomo.WithZipperConnMiddleware(ai.RegisterFunctionMW()))
         }
         // new zipper
         zipper, err := yomo.NewZipper(
@@ -97,7 +98,7 @@ var serveCmd = &cobra.Command{
             registerAIProvider(aiConfig)
             // start the llm api server
             go func() {
-                err := ai.Serve(aiConfig, listenAddr, fmt.Sprintf("token:%s", tokenString))
+                err := ai.Serve(aiConfig, listenAddr, fmt.Sprintf("token:%s", tokenString), ylog.Default())
                 if err != nil {
                     log.FailureStatusEvent(os.Stdout, err.Error())
                     return
@@ -118,16 +119,16 @@ func registerAIProvider(aiConfig *ai.Config) error {
     for name, provider := range aiConfig.Providers {
         switch name {
         case "azopenai":
-            ai.RegisterProvider(azopenai.NewProvider(
+            providerpkg.RegisterProvider(azopenai.NewProvider(
                 provider["api_key"],
                 provider["api_endpoint"],
                 provider["deployment_id"],
                 provider["api_version"],
             ))
         case "openai":
-            ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
+            providerpkg.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
         case "cloudflare_azure":
-            ai.RegisterProvider(cfazure.NewProvider(
+            providerpkg.RegisterProvider(cfazure.NewProvider(
                 provider["endpoint"],
                 provider["api_key"],
                 provider["resource"],
@@ -135,20 +136,19 @@ func registerAIProvider(aiConfig *ai.Config) error {
                 provider["api_version"],
             ))
         case "cloudflare_openai":
-            ai.RegisterProvider(cfopenai.NewProvider(
+            providerpkg.RegisterProvider(cfopenai.NewProvider(
                 provider["endpoint"],
                 provider["api_key"],
                 provider["model"],
             ))
         case "ollama":
-            ai.RegisterProvider(ollama.NewProvider(provider["api_endpoint"]))
+            providerpkg.RegisterProvider(ollama.NewProvider(provider["api_endpoint"]))
         default:
             log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name)
         }
     }
 
-    // log.InfoStatusEvent(os.Stdout, "registered [%d] AI provider", len(ai.ListProviders()))
-    ylog.Info("registered AI providers", "len", len(ai.ListProviders()))
+    ylog.Info("registered AI providers", "len", len(providerpkg.ListProviders()))
 
     return nil
 }
diff --git a/cli/test.go b/cli/test.go
index ab7ef395f..72c4e9181 100644
--- a/cli/test.go
+++ b/cli/test.go
@@ -214,7 +214,7 @@ var testPromptCmd = &cobra.Command{
 func getToolCallResult(tc *openai.ToolCall, tms []ai.ToolMessage) string {
     result := ""
     for _, tm := range tms {
-        if tm.ToolCallId == tc.ID {
+        if tm.ToolCallID == tc.ID {
             result = tm.Content
         }
     }
diff --git a/core/connection.go b/core/connection.go
index d5e877572..2a0bbdef1 100644
--- a/core/connection.go
+++ b/core/connection.go
@@ -44,7 +44,8 @@ type Connection struct {
     Logger *slog.Logger
 }
 
-func newConnection(
+// NewConnection creates a new connection according to the parameters.
+func NewConnection(
     id uint64, name string, clientID string, clientType ClientType, md metadata.M, tags []uint32,
     fconn frame.Conn, logger *slog.Logger,
diff --git a/core/connection_test.go b/core/connection_test.go
index c1ca2c321..b6e1ee66c 100644
--- a/core/connection_test.go
+++ b/core/connection_test.go
@@ -18,7 +18,7 @@ func TestConnection(t *testing.T) {
         md metadata.M
     )
 
-    connection := newConnection(id, name, clientID, styp, md, observed, nil, ylog.Default())
+    connection := NewConnection(id, name, clientID, styp, md, observed, nil, ylog.Default())
 
     t.Run("ConnectionInfo", func(t *testing.T) {
         assert.Equal(t, id, connection.ID())
diff --git a/core/connector_test.go b/core/connector_test.go
index 4dbf0606d..0a3d31e4f 100644
--- a/core/connector_test.go
+++ b/core/connector_test.go
@@ -117,5 +117,5 @@ func TestConnector(t *testing.T) {
 // mockConn returns a connection that only includes an ID and a name.
 // This function is used for unit testing purposes.
 func mockConn(id uint64, name string) *Connection {
-    return newConnection(id, name, "mock-id", ClientType(0), nil, []frame.Tag{0}, nil, ylog.Default())
+    return NewConnection(id, name, "mock-id", ClientType(0), nil, []frame.Tag{0}, nil, ylog.Default())
 }
diff --git a/core/router/router.go b/core/router/router.go
index 5fceffe2c..a41c443b1 100644
--- a/core/router/router.go
+++ b/core/router/router.go
@@ -33,9 +33,9 @@ type defaultRouter struct {
     data map[frame.Tag]map[uint64]struct{}
 }
 
-// DefaultRouter provides a default implementation of `router`,
+// Default provides a default implementation of `router`,
 // It routes data according to observed tag and metadata.
-func Default() *defaultRouter {
+func Default() Router {
     return &defaultRouter{
         targets: make(map[uint64]string),
         data:    make(map[frame.Tag]map[uint64]struct{}),
diff --git a/core/server.go b/core/server.go
index 79b54586e..6e1d722f6 100644
--- a/core/server.go
+++ b/core/server.go
@@ -293,7 +293,7 @@ func (s *Server) createConnection(hf *frame.HandshakeFrame, md metadata.M, fconn
     if hf.WantedTarget != "" {
         md.Set(metadata.WantedTargetKey, hf.WantedTarget)
     }
-    conn := newConnection(
+    conn := NewConnection(
         incrID(),
         hf.Name,
         hf.ID,
diff --git a/options.go b/options.go
index 5bea59a79..245b51551 100644
--- a/options.go
+++ b/options.go
@@ -112,6 +112,7 @@ var (
         }
     }
 
+    // WithVersionNegotiateFunc sets the version negotiation function for the zipper.
     WithVersionNegotiateFunc = func(f core.VersionNegotiateFunc) ZipperOption {
         return func(zo *zipperOptions) {
             zo.serverOption = append(zo.serverOption, core.WithVersionNegotiateFunc(f))
diff --git a/pkg/bridge/ai/ai.go b/pkg/bridge/ai/ai.go
index 94fadaac5..1e253baf7 100644
--- a/pkg/bridge/ai/ai.go
+++ b/pkg/bridge/ai/ai.go
@@ -22,43 +22,48 @@ var (
     ErrConfigFormatError = errors.New("ai config format is incorrect")
 )
 
-// ConnMiddleware returns a ConnMiddleware that can be used to intercept the connection.
-func ConnMiddleware(next core.ConnHandler) core.ConnHandler {
-    return func(conn *core.Connection) {
-        connMd := conn.Metadata().Clone()
-        definition, ok := connMd.Get(ai.FunctionDefinitionKey)
+// RegisterFunctionMW returns a ConnMiddleware that can be used to register an ai function.
+func RegisterFunctionMW() core.ConnMiddleware { return registerFunction(register.GetRegister()) }
 
-        defer func() {
-            // definition does not be transmitted in mesh network, It only works for handshake.
-            conn.Metadata().Set(ai.FunctionDefinitionKey, "")
-            next(conn)
-            if ok {
-                register.UnregisterFunction(conn.ID(), connMd)
-                conn.Logger.Info("unregister ai function", "name", conn.Name(), "connID", conn.ID())
-            }
-        }()
+func registerFunction(r register.Register) core.ConnMiddleware {
+    return core.ConnMiddleware(func(next core.ConnHandler) core.ConnHandler {
+        return func(conn *core.Connection) {
+            connMd := conn.Metadata().Clone()
+            definition, ok := connMd.Get(ai.FunctionDefinitionKey)
 
-        if conn.ClientType() != core.ClientTypeStreamFunction || !ok {
-            return
-        }
+            defer func() {
+                if ok {
+                    conn.Metadata().Set(ai.FunctionDefinitionKey, "")
+                }
+                // the definition is not transmitted in the mesh network; it is only used during handshake.
+                next(conn)
+                if ok {
+                    register.UnregisterFunction(conn.ID(), connMd)
+                    conn.Logger.Info("unregister ai function", "name", conn.Name(), "connID", conn.ID())
+                }
+            }()
 
-        for _, tag := range conn.ObserveDataTags() {
-            // register ai function
-            fd := ai.FunctionDefinition{}
-            err := json.Unmarshal([]byte(definition), &fd)
-            if err != nil {
-                conn.Logger.Error("unmarshal function definition", "error", err)
+            if conn.ClientType() != core.ClientTypeStreamFunction || !ok {
                 return
             }
-            err = register.RegisterFunction(tag, &fd, conn.ID(), connMd)
-            if err != nil {
-                conn.Logger.Error("failed to register ai function", "name", conn.Name(), "tag", tag, "err", err)
-                return
+
+            for _, tag := range conn.ObserveDataTags() {
+                // register ai function
+                fd := ai.FunctionDefinition{}
+                err := json.Unmarshal([]byte(definition), &fd)
+                if err != nil {
+                    conn.Logger.Error("unmarshal function definition", "error", err)
+                    return
+                }
+                err = r.RegisterFunction(tag, &fd, conn.ID(), connMd)
+                if err != nil {
+                    conn.Logger.Error("failed to register ai function", "name", conn.Name(), "tag", tag, "err", err)
+                    return
+                }
+                conn.Logger.Info("register ai function success", "name", conn.Name(), "tag", tag, "definition", string(definition))
             }
-            conn.Logger.Info("register ai function success", "name", conn.Name(), "tag", tag, "definition", string(definition))
         }
-
-    }
+    })
 }
 
 // Config is the configuration of AI bridge.
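Note: a minimal usage sketch of the middleware above, mirroring TestRegisterFunction in the ai_test.go changes below. Every identifier comes from this changeset; the `my-fn` name and the 0x33 tag are illustrative values only.

    r := register.NewDefault()
    handler := registerFunction(r)(func(c *core.Connection) {})

    // an sfn connection whose handshake metadata carries a function definition
    md := metadata.M{ai.FunctionDefinitionKey: `{"name": "my-fn"}`}
    conn := core.NewConnection(1, "my-fn", "client-id", core.ClientTypeStreamFunction, md, []frame.Tag{0x33}, nil, ylog.Default())

    handler(conn)
    // r.ListToolCalls(conn.Metadata()) now lists "my-fn" as an openai.Tool under tag 0x33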
diff --git a/pkg/bridge/ai/ai_test.go b/pkg/bridge/ai/ai_test.go
index 451543c6d..e2c809962 100644
--- a/pkg/bridge/ai/ai_test.go
+++ b/pkg/bridge/ai/ai_test.go
@@ -1,19 +1,45 @@
 package ai
 
 import (
-    "context"
-    "net/http"
-    "net/http/httptest"
-    "sync"
+    "fmt"
     "testing"
 
     openai "github.com/sashabaranov/go-openai"
     "github.com/stretchr/testify/assert"
     "github.com/yomorun/yomo/ai"
+    "github.com/yomorun/yomo/core"
+    "github.com/yomorun/yomo/core/frame"
     "github.com/yomorun/yomo/core/metadata"
+    "github.com/yomorun/yomo/core/ylog"
     "github.com/yomorun/yomo/pkg/bridge/ai/register"
 )
 
+func TestRegisterFunction(t *testing.T) {
+    r := register.NewDefault()
+    connHandler := registerFunction(r)(func(c *core.Connection) {})
+
+    t.Run("source", func(t *testing.T) {
+        conn := mockSourceConn(1, "source")
+        connHandler(conn)
+
+        toolCalls, _ := r.ListToolCalls(conn.Metadata())
+        assert.Equal(t, map[uint32]openai.Tool{}, toolCalls)
+    })
+
+    t.Run("stream function", func(t *testing.T) {
+        conn := mockSfnConn(2, "sfn")
+        connHandler(conn)
+
+        toolCalls, _ := r.ListToolCalls(conn.Metadata())
+
+        want := map[uint32]openai.Tool{
+            0x33: {Type: "function", Function: &openai.FunctionDefinition{Name: "sfn"}},
+        }
+
+        assert.Equal(t, want, toolCalls)
+    })
+}
+
 func TestParseZipperAddr(t *testing.T) {
     tests := []struct {
         name string
@@ -129,211 +155,13 @@ func TestParseConfig(t *testing.T) {
     }
 }
 
-type MockLLMProvider struct {
-    name string
-}
-
-func (m *MockLLMProvider) GetChatCompletions(_ context.Context, req openai.ChatCompletionRequest, _ metadata.M) (openai.ChatCompletionResponse, error) {
-    return openai.ChatCompletionResponse{}, nil
-}
-
-func (m *MockLLMProvider) GetChatCompletionsStream(_ context.Context, req openai.ChatCompletionRequest, _ metadata.M) (ResponseRecver, error) {
-    return nil, nil
-}
-
-var _ LLMProvider = &MockLLMProvider{}
-
-func (m *MockLLMProvider) Name() string {
-    return m.name
-}
-
-func TestListProviders(t *testing.T) {
-    t.Cleanup(func() {
-        providers = sync.Map{}
-        defaultProvider = nil
-    })
-    providers.Store("provider1", &MockLLMProvider{})
-    providers.Store("provider2", &MockLLMProvider{})
-    providers.Store("provider3", &MockLLMProvider{})
-
-    expected := []string{"provider1", "provider2", "provider3"} // Replace with the expected provider names
-
-    val := ListProviders()
-
-    assert.ElementsMatch(t, expected, val)
-}
-
-func TestRegisterProvider(t *testing.T) {
-    t.Cleanup(func() {
-        providers = sync.Map{}
-        defaultProvider = nil
-    })
-    provider := &MockLLMProvider{name: "testProvider"}
-
-    // Test registering a provider
-    RegisterProvider(provider)
-
-    // Check if the provider was registered
-    storedProvider, ok := providers.Load(provider.Name())
-    assert.True(t, ok)
-    assert.Equal(t, provider, storedProvider)
-}
-
-func TestSetDefaultProvider(t *testing.T) {
-    t.Cleanup(func() {
-        providers = sync.Map{}
-        defaultProvider = nil
-    })
-    provider := &MockLLMProvider{name: "testProvider"}
-
-    // Register a provider
-    RegisterProvider(provider)
-
-    // Set the provider as default
-    SetDefaultProvider(provider.Name())
-
-    // Check if the default provider was set
-    assert.Equal(t, provider, defaultProvider)
-}
-
-func TestGetProviderAndSetDefault(t *testing.T) {
-    t.Cleanup(func() {
-        providers = sync.Map{}
-        defaultProvider = nil
-    })
-    provider := &MockLLMProvider{name: "testProvider"}
-
-    // Register a provider
-    RegisterProvider(provider)
-
-    // Get the provider and set it as default
-    retrievedProvider, err := GetProviderAndSetDefault(provider.Name())
-
-    // Check if the correct provider was retrieved and set as default
-    assert.NoError(t, err)
-    assert.Equal(t, provider, retrievedProvider)
-    assert.Equal(t, provider, defaultProvider)
-}
-
-func TestGetDefaultProvider(t *testing.T) {
-    t.Cleanup(func() {
-        providers = sync.Map{}
-        defaultProvider = nil
-    })
-
-    provider1 := &MockLLMProvider{name: "provider1"}
-    provider2 := &MockLLMProvider{name: "provider2"}
-
-    // Register first provider
-    RegisterProvider(provider1)
-
-    // Test getting the default provider when none is set
-    // The first available provider should be returned
-    p, err := GetDefaultProvider()
-    assert.NoError(t, err)
-    assert.Equal(t, provider1, p)
-
-    // Register second provider
-    RegisterProvider(provider2)
-
-    // Set the second provider as default
-    SetDefaultProvider(provider2.Name())
-
-    // Test getting the default provider when one is set
-    // The default provider should be returned
-    p, err = GetDefaultProvider()
-    assert.NoError(t, err)
-    assert.Equal(t, provider2, p)
-}
-
-func TestServiceContext(t *testing.T) {
-    // Create a new service
-    service := &Service{}
-
-    // Create a new context
-    ctx := context.Background()
-
-    // Add the service to the context
-    ctx = WithServiceContext(ctx, service)
-
-    // Retrieve the service from the context
-    retrievedService := FromServiceContext(ctx)
-
-    // Check if the correct service was retrieved
-    assert.Equal(t, service, retrievedService)
-
-    // Test with a context that does not contain a service
-    ctx = context.Background()
-    retrievedService = FromServiceContext(ctx)
-
-    // Check if no service was retrieved
-    assert.Nil(t, retrievedService)
-}
-
-func TestHandleOverview(t *testing.T) {
-    functionDefinition := &ai.FunctionDefinition{
-        Name:        "function1",
-        Description: "desc1",
-        Parameters: &ai.FunctionParameters{
-            Type: "type1",
-            Properties: map[string]*ai.ParameterProperty{
-                "prop1": {Type: "type1", Description: "desc1"},
-                "prop2": {Type: "type2", Description: "desc2"},
-            },
-            Required: []string{"prop1"},
-        },
+func mockSfnConn(id uint64, name string) *core.Connection {
+    md := metadata.M{
+        ai.FunctionDefinitionKey: fmt.Sprintf(`{"name": "%s"}`, name),
     }
-    r := register.GetRegister()
-    r.RegisterFunction(100, functionDefinition, 200, nil)
-
-    register.SetRegister(r)
-
-    // Create a new mock service
-    service := &Service{
-        LLMProvider: &MockLLMProvider{},
-    }
-
-    // Create a new request
-    req, err := http.NewRequest("GET", "/overview", nil)
-    assert.NoError(t, err)
-
-    // Add the service to the request context
-    req = req.WithContext(WithServiceContext(req.Context(), service))
-
-    // Record the response
-    rr := httptest.NewRecorder()
-
-    // Create a handler function
-    handler := http.HandlerFunc(HandleOverview)
-
-    // Serve the request
-    handler.ServeHTTP(rr, req)
-
-    // Check the response status code
-    assert.Equal(t, http.StatusOK, rr.Code)
-
-    // Check the response body
-    // This is a basic check for an empty body, replace with your own logic
-    assert.Equal(t, "{\"Functions\":{\"100\":{\"name\":\"function1\",\"description\":\"desc1\",\"parameters\":{\"type\":\"type1\",\"properties\":{\"prop1\":{\"type\":\"type1\",\"description\":\"desc1\"},\"prop2\":{\"type\":\"type2\",\"description\":\"desc2\"}},\"required\":[\"prop1\"]}}}}\n", rr.Body.String())
+    return core.NewConnection(id, name, "mock-sfn-id", core.ClientTypeStreamFunction, md, []frame.Tag{0x33}, nil, ylog.Default())
 }
 
-func TestNewBasicAPIServer(t *testing.T) {
-    // Create a new mock provider
-    provider := &MockLLMProvider{name: "testProvider"}
-
-    // Create a new config
-    config := &Config{}
-
-    // Call the NewBasicAPIServer function
-    server, err := NewBasicAPIServer("testServer", config, "localhost:8080", provider, "testCredential")
-
-    // Check if no error was returned
-    assert.NoError(t, err)
-
-    // Check if the server was correctly created
-    assert.Equal(t, "testServer", server.Name)
-    assert.Equal(t, config, server.Config)
-    assert.Equal(t, "localhost:8080", server.ZipperAddr)
-    assert.Equal(t, provider, server.Provider)
-    assert.Equal(t, "testCredential", server.serviceCredential)
+func mockSourceConn(id uint64, name string) *core.Connection {
+    return core.NewConnection(id, name, "mock-source-id", core.ClientTypeSource, metadata.New(), []frame.Tag{0x33}, nil, ylog.Default())
 }
diff --git a/pkg/bridge/ai/api_server.go b/pkg/bridge/ai/api_server.go
index ac038aba7..1776efa55 100644
--- a/pkg/bridge/ai/api_server.go
+++ b/pkg/bridge/ai/api_server.go
@@ -5,13 +5,15 @@ import (
     "encoding/json"
     "errors"
     "fmt"
+    "log/slog"
     "net"
     "net/http"
     "time"
 
     openai "github.com/sashabaranov/go-openai"
     "github.com/yomorun/yomo/ai"
-    "github.com/yomorun/yomo/core/ylog"
+    "github.com/yomorun/yomo/pkg/bridge/ai/provider"
+    "github.com/yomorun/yomo/pkg/bridge/ai/register"
     "github.com/yomorun/yomo/pkg/id"
 )
 
@@ -29,189 +31,178 @@ var (
 
 // BasicAPIServer provides restful service for end user
 type BasicAPIServer struct {
-    // Name is the name of the server
-    Name string
-    // Config is the configuration of the server
-    *Config
-    // ZipperAddr is the address of the zipper
-    ZipperAddr string
-    // Provider is the llm provider
-    Provider LLMProvider
-    // serviceCredential is the credential for Function Calling Service
-    serviceCredential string
+    zipperAddr  string
+    credential  string
+    httpHandler http.Handler
+    logger      *slog.Logger
 }
 
 // Serve starts the Basic API Server
-func Serve(config *Config, zipperListenAddr string, credential string) error {
-    provider, err := GetProviderAndSetDefault(config.Server.Provider)
+func Serve(config *Config, zipperListenAddr string, credential string, logger *slog.Logger) error {
+    provider, err := provider.GetProvider(config.Server.Provider)
     if err != nil {
         return err
     }
-    srv, err := NewBasicAPIServer(provider.Name(), config, zipperListenAddr, provider, credential)
+    srv, err := NewBasicAPIServer(config, zipperListenAddr, provider, credential, logger)
     if err != nil {
         return err
     }
-    return srv.Serve()
-}
 
-// NewBasicAPIServer creates a new restful service
-func NewBasicAPIServer(name string, config *Config, zipperAddr string, provider LLMProvider, credential string) (*BasicAPIServer, error) {
-    zipperAddr = parseZipperAddr(zipperAddr)
-    return &BasicAPIServer{
-        Name:              name,
-        Config:            config,
-        ZipperAddr:        zipperAddr,
-        Provider:          provider,
-        serviceCredential: credential,
-    }, nil
+    logger.Info("start bridge server", "addr", config.Server.Addr, "provider", provider.Name())
+    return srv.ServeAddr(config.Server.Addr)
 }
 
-// Serve starts a RESTful service that provides a '/invoke' endpoint.
-// Users submit questions to this endpoint. The service then generates a prompt based on the question and
-// registered functions. It calls the completion api by llm provider to get the functions and arguments to be
-// invoked. These functions are invoked sequentially by YoMo. all the functions write their results to the
-// reducer-sfn.
-func (a *BasicAPIServer) Serve() error {
+// BridgeHTTPHandler returns an http.Handler that serves the bridge endpoints, wrapped by the given decorator.
+func BridgeHTTPHandler(decorator func(http.Handler) http.Handler) http.Handler {
     mux := http.NewServeMux()
     // GET /overview
     mux.HandleFunc("/overview", HandleOverview)
     // POST /invoke
     mux.HandleFunc("/invoke", HandleInvoke)
-    // POST /v1/chat/completions OpenAI compatible interface
+    // POST /v1/chat/completions (OpenAI compatible interface)
     mux.HandleFunc("/v1/chat/completions", HandleChatCompletions)
 
-    handler := WithContextService(mux, a.serviceCredential, a.ZipperAddr, a.Provider, DefaultExchangeMetadataFunc)
+    return decorator(mux)
+}
+
+// NewBasicAPIServer creates a new restful service
+func NewBasicAPIServer(config *Config, zipperAddr string, provider provider.LLMProvider, credential string, logger *slog.Logger) (*BasicAPIServer, error) {
+    zipperAddr = parseZipperAddr(zipperAddr)
 
-    addr := a.Config.Server.Addr
-    ylog.Info("server is running", "addr", addr, "ai_provider", a.Name)
-    return http.ListenAndServe(addr, handler)
+    cp := NewCallerProvider(zipperAddr, provider, DefaultExchangeMetadataFunc)
+
+    server := &BasicAPIServer{
+        zipperAddr:  zipperAddr,
+        credential:  credential,
+        httpHandler: BridgeHTTPHandler(decorateReqContext(cp, logger, credential)),
+        logger:      logger.With("component", "bridge"),
+    }
+
+    return server, nil
 }
 
-// WithContextService adds the service to the request context
-func WithContextService(handler http.Handler, credential string, zipperAddr string, provider LLMProvider, exFn ExchangeMetadataFunc) http.Handler {
-    // create service instance when the api server starts
-    service, err := LoadOrCreateService(credential, zipperAddr, provider, exFn)
-    if err != nil {
+// ServeAddr starts an HTTP server that provides endpoints bridging the http server and YoMo.
+// Users can chat with the http server and interact with YoMo's stream functions.
+func (a *BasicAPIServer) ServeAddr(addr string) error {
+    return http.ListenAndServe(addr, a.httpHandler)
+}
+
+// decorateReqContext decorates the context of the request; it injects a transID and a caller into the context.
+func decorateReqContext(cp *CallerProvider, logger *slog.Logger, credential string) func(handler http.Handler) http.Handler {
+    return func(handler http.Handler) http.Handler {
+        caller, err := cp.Provide(credential)
+        if err != nil {
+            logger.Info("can't load caller", "err", err)
+
+            return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+                w.Header().Set("Content-Type", "application/json")
+                RespondWithError(w, http.StatusInternalServerError, err)
+            })
+        }
+
         return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-            w.WriteHeader(http.StatusInternalServerError)
-            json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
+            transID := id.New(32)
+            ctx := WithTransIDContext(r.Context(), transID)
+            ctx = WithCallerContext(ctx, caller)
+
+            logger.Info("request", "method", r.Method, "path", r.URL.Path, "transID", transID)
+
+            handler.ServeHTTP(w, r.WithContext(ctx))
         })
     }
-
-    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-        transID := id.New(32)
-        ctx := WithTransIDContext(r.Context(), transID)
-        ctx = WithServiceContext(ctx, service)
-        handler.ServeHTTP(w, r.WithContext(ctx))
-    })
 }
 
 // HandleOverview is the handler for GET /overview
 func HandleOverview(w http.ResponseWriter, r *http.Request) {
-    service := FromServiceContext(r.Context())
+    caller := FromCallerContext(r.Context())
     w.Header().Set("Content-Type", "application/json")
-    // credential := getBearerToken(r)
-    resp, err := service.GetOverview()
+
+    tcs, err := register.ListToolCalls(caller.Metadata())
     if err != nil {
         w.WriteHeader(http.StatusInternalServerError)
         json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
         return
     }
+
+    functions := make(map[uint32]*openai.FunctionDefinition)
+    for tag, tc := range tcs {
+        functions[tag] = tc.Function
+    }
+
     w.WriteHeader(http.StatusOK)
-    json.NewEncoder(w).Encode(resp)
+    json.NewEncoder(w).Encode(&ai.OverviewResponse{Functions: functions})
 }
 
+var baseSystemMessage = `You are a very helpful assistant. Your job is to choose the best possible action to solve the user question or task. Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.`
+
 // HandleInvoke is the handler for POST /invoke
 func HandleInvoke(w http.ResponseWriter, r *http.Request) {
     var (
         ctx     = r.Context()
-        service = FromServiceContext(ctx)
+        caller  = FromCallerContext(ctx)
         transID = FromTransIDContext(ctx)
     )
     defer r.Body.Close()
 
-    var req ai.InvokeRequest
-
-    // decode the request
-    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-        ylog.Error("decode request", "err", err.Error())
-        w.WriteHeader(http.StatusBadRequest)
-        json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
+    req, err := DecodeRequest[ai.InvokeRequest](r, w)
+    if err != nil {
         return
     }
-    w.Header().Set("Content-Type", "application/json")
-
-    // Create a context with a timeout of 5 seconds
-    ctx, cancel := context.WithTimeout(context.Background(), RequestTimeout)
+    ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout)
     defer cancel()
 
-    // messages
-    baseSystemMessage := `You are a very helpful assistant. Your job is to choose the best possible action to solve the user question or task. Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.`
-
-    // Make the service call in a separate goroutine, and use a channel to get the result
-    resCh := make(chan *ai.InvokeResponse, 1)
-    errCh := make(chan error, 1)
-    go func(service *Service, req ai.InvokeRequest, baseSystemMessage string) {
-        // call llm to infer the function and arguments to be invoked
-        ylog.Debug(">> ai request", "transID", transID, "prompt", req.Prompt)
-        res, err := service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, req.IncludeCallStack)
-        if err != nil {
-            errCh <- err
-        } else {
-            resCh <- res
-        }
-    }(service, req, baseSystemMessage)
-
-    // Use a select statement to handle the result or timeout
-    select {
-    case res := <-resCh:
-        ylog.Debug(">> ai response response", "res", fmt.Sprintf("%+v", res))
-        // write the response to the client with res
-        w.WriteHeader(http.StatusOK)
-        json.NewEncoder(w).Encode(res)
-    case err := <-errCh:
-        ylog.Error("invoke service", "err", err.Error())
-        w.WriteHeader(http.StatusInternalServerError)
-        json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
-    case <-ctx.Done():
-        // The context was cancelled, which means the service call timed out
-        w.WriteHeader(http.StatusRequestTimeout)
-        json.NewEncoder(w).Encode(map[string]string{"error": "request timed out"})
+    res, err := caller.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, req.IncludeCallStack)
+    if err != nil {
+        w.Header().Set("Content-Type", "application/json")
+        RespondWithError(w, http.StatusInternalServerError, err)
+        return
     }
+
+    w.Header().Set("Content-Type", "application/json")
+    w.WriteHeader(http.StatusOK)
+    _ = json.NewEncoder(w).Encode(res)
 }
 
 // HandleChatCompletions is the handler for POST /chat/completion
 func HandleChatCompletions(w http.ResponseWriter, r *http.Request) {
     var (
         ctx     = r.Context()
-        service = FromServiceContext(ctx)
+        caller  = FromCallerContext(ctx)
         transID = FromTransIDContext(ctx)
     )
     defer r.Body.Close()
 
-    var req openai.ChatCompletionRequest
-    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-        ylog.Error("decode request", "err", err.Error())
-        RespondWithError(w, http.StatusBadRequest, err)
+    req, err := DecodeRequest[openai.ChatCompletionRequest](r, w)
+    if err != nil {
        return
     }
 
     ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout)
     defer cancel()
 
-    if err := service.GetChatCompletions(ctx, req, transID, w, false); err != nil {
-        ylog.Error("invoke chat completions", "err", err.Error())
+    if err := caller.GetChatCompletions(ctx, req, transID, w); err != nil {
         RespondWithError(w, http.StatusBadRequest, err)
         return
     }
 }
 
+// DecodeRequest decodes the request body into the given type.
+func DecodeRequest[T any](r *http.Request, w http.ResponseWriter) (T, error) {
+    var req T
+    err := json.NewDecoder(r.Body).Decode(&req)
+    if err != nil {
+        w.Header().Set("Content-Type", "application/json")
+        RespondWithError(w, http.StatusBadRequest, err)
+        return req, err
+    }
+
+    return req, nil
+}
+
 // RespondWithError writes an error to response according to the OpenAI API spec.
 func RespondWithError(w http.ResponseWriter, code int, err error) {
     w.WriteHeader(code)
-    w.Header().Set("Content-Type", "application/json")
     w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, err.Error())))
 }
 
@@ -231,16 +222,16 @@ func getLocalIP() (string, error) {
     return "", errors.New("not found local ip")
 }
 
-type serviceContextKey struct{}
+type callerContextKey struct{}
 
-// WithServiceContext adds the service to the request context
-func WithServiceContext(ctx context.Context, service *Service) context.Context {
-    return context.WithValue(ctx, serviceContextKey{}, service)
+// WithCallerContext adds the caller to the request context
+func WithCallerContext(ctx context.Context, caller *Caller) context.Context {
+    return context.WithValue(ctx, callerContextKey{}, caller)
 }
 
-// FromServiceContext returns the service from the request context
-func FromServiceContext(ctx context.Context) *Service {
-    service, ok := ctx.Value(serviceContextKey{}).(*Service)
+// FromCallerContext returns the caller from the request context
+func FromCallerContext(ctx context.Context) *Caller {
+    service, ok := ctx.Value(callerContextKey{}).(*Caller)
     if !ok {
         return nil
     }
diff --git a/pkg/bridge/ai/api_server_test.go b/pkg/bridge/ai/api_server_test.go
new file mode 100644
index 000000000..901a46aed
--- /dev/null
+++ b/pkg/bridge/ai/api_server_test.go
@@ -0,0 +1,116 @@
+package ai
+
+import (
+    "bytes"
+    "fmt"
+    "io"
+    "log/slog"
+    "net/http"
+    "net/http/httptest"
+    "testing"
+
+    "github.com/stretchr/testify/assert"
+    "github.com/yomorun/yomo/ai"
+    "github.com/yomorun/yomo/pkg/bridge/ai/provider"
+    "github.com/yomorun/yomo/pkg/bridge/ai/register"
+)
+
+func TestServer(t *testing.T) {
+    // register a function definition to the register
+    functionDefinition := &ai.FunctionDefinition{
+        Name:        "function1",
+        Description: "desc1",
+        Parameters: &ai.FunctionParameters{
+            Type: "type1",
+            Properties: map[string]*ai.ParameterProperty{
+                "prop1": {Type: "type1", Description: "desc1"},
+                "prop2": {Type: "type2", Description: "desc2"},
+            },
+            Required: []string{"prop1"},
+        },
+    }
+    register.SetRegister(register.NewDefault())
+    register.RegisterFunction(100, functionDefinition, 200, nil)
+
+    // mock the provider and the req/res of the caller
+    pd, err := provider.NewMock("mock provider", provider.MockChatCompletionResponse(stopResp, stopResp))
+    if err != nil {
+        t.Fatal(err)
+    }
+
+    cp := newMockCallerProvider()
+
+    cp.provideFunc = mockCallerProvideFunc(map[uint32][]mockFunctionCall{}, pd)
+
+    handler := BridgeHTTPHandler(decorateReqContext(cp, slog.Default(), ""))
+
+    // create a test server
+    server := httptest.NewServer(handler)
+
+    httpClient := server.Client()
+
+    t.Run("GET /overview", func(t *testing.T) {
+        url := fmt.Sprintf("%s/overview", server.URL)
+
+        resp, err := httpClient.Get(url)
+        if err != nil {
+            t.Fatal(err)
+        }
+        defer resp.Body.Close()
+
+        assert.Equal(t, http.StatusOK, resp.StatusCode)
+        assert.Equal(t, "application/json", resp.Header.Get("Content-Type"))
+
+        body, _ := io.ReadAll(resp.Body)
+        assert.Equal(t, `{"Functions":{"100":{"name":"function1","description":"desc1","parameters":{"type":"type1","properties":{"prop1":{"type":"type1","description":"desc1"},"prop2":{"type":"type2","description":"desc2"}},"required":["prop1"]}}}}
+`, string(body))
+    })
+
+    t.Run("POST /invoke", func(t *testing.T) {
+        url := fmt.Sprintf("%s/invoke", server.URL)
+
+        resp, err := httpClient.Post(url, "application/json", bytes.NewBufferString(`{"prompt": "Hi, How are you"}`))
+        if err != nil {
+            t.Fatal(err)
+        }
+        defer resp.Body.Close()
+
+        assert.Equal(t, http.StatusOK, resp.StatusCode)
+        assert.Equal(t, "application/json", resp.Header.Get("Content-Type"))
+
+        body, _ := io.ReadAll(resp.Body)
+        assert.Equal(t, "{\"content\":\"Hello! I'm just a computer program, so I don't have feelings, but thanks for asking. How can I assist you today?\",\"finish_reason\":\"stop\",\"token_usage\":{\"prompt_tokens\":13,\"completion_tokens\":26}}\n", string(body))
+    })
+
+    t.Run("POST /v1/chat/completions", func(t *testing.T) {
+        url := fmt.Sprintf("%s/v1/chat/completions", server.URL)
+
+        resp, err := httpClient.Post(url, "application/json", bytes.NewBufferString(`{"messages":[{"role":"user","content":"Hi, How are you"}]}`))
+        if err != nil {
+            t.Fatal(err)
+        }
+        defer resp.Body.Close()
+
+        assert.Equal(t, http.StatusOK, resp.StatusCode)
+        assert.Equal(t, "application/json", resp.Header.Get("Content-Type"))
+
+        body, _ := io.ReadAll(resp.Body)
+        assert.Equal(t, "{\"id\":\"chatcmpl-9blYknv9rHvr2dvCQKMeW21hlBpCX\",\"object\":\"chat.completion\",\"created\":1718787982,\"model\":\"gpt-4o-2024-05-13\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"Hello! I'm just a computer program, so I don't have feelings, but thanks for asking. How can I assist you today?\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":13,\"completion_tokens\":26,\"total_tokens\":39},\"system_fingerprint\":\"fp_f4e629d0a5\"}\n", string(body))
+    })
+
+    t.Run("illegal request", func(t *testing.T) {
+        url := fmt.Sprintf("%s/v1/chat/completions", server.URL)
+
+        resp, err := httpClient.Post(url, "application/json", bytes.NewBufferString(`some illegal request`))
+        if err != nil {
+            t.Fatal(err)
+        }
+        defer resp.Body.Close()
+
+        assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
+        assert.Equal(t, "application/json", resp.Header.Get("Content-Type"))
+
+        body, _ := io.ReadAll(resp.Body)
+        assert.Equal(t, "{\"error\":{\"code\":\"400\",\"message\":\"invalid character 's' looking for beginning of value\"}}", string(body))
+    })
+}
diff --git a/pkg/bridge/ai/call_syncer.go b/pkg/bridge/ai/call_syncer.go
new file mode 100644
index 000000000..c2ae78b51
--- /dev/null
+++ b/pkg/bridge/ai/call_syncer.go
@@ -0,0 +1,248 @@
+package ai
+
+import (
+    "context"
+    "log/slog"
+    "time"
+
+    openai "github.com/sashabaranov/go-openai"
+    "github.com/yomorun/yomo"
+    "github.com/yomorun/yomo/ai"
+    "github.com/yomorun/yomo/serverless"
+)
+
+// CallSyncer fires a bunch of function calls and waits for their results.
+// Every tool call has a toolCallID, which identifies the function call.
+// Note that each tool call can only be responded to once.
+type CallSyncer interface {
+    // Call fires a bunch of function calls and waits for their results.
+    // The result only contains the messages with role=="tool".
+    // If some function calls fail, the failure reason is returned as the message content.
+    Call(ctx context.Context, transID string, reqID string, toolCalls map[uint32][]*openai.ToolCall) ([]openai.ChatCompletionMessage, error)
+    // Close closes the CallSyncer. After closing, the CallSyncer can no longer be used.
+    Close() error
+}
+
+type callSyncer struct {
+    ctx    context.Context
+    cancel context.CancelFunc
+    logger *slog.Logger
+
+    // timeout is the timeout for waiting for the results.
+    timeout   time.Duration
+    source    yomo.Source
+    reducer   yomo.StreamFunction
+    reduceCh  <-chan reduceMessage
+    toolOutCh chan toolOut
+    cleanCh   chan string
+}
+
+type reduceMessage struct {
+    reqID   string
+    message openai.ChatCompletionMessage
+}
+
+// NewCallSyncer creates a new CallSyncer.
+func NewCallSyncer(logger *slog.Logger, source yomo.Source, reducer yomo.StreamFunction, timeout time.Duration) CallSyncer {
+    if timeout == 0 {
+        timeout = RunFunctionTimeout
+    }
+    ctx, cancel := context.WithCancel(context.Background())
+
+    syncer := &callSyncer{
+        ctx:       ctx,
+        cancel:    cancel,
+        logger:    logger,
+        timeout:   timeout,
+        source:    source,
+        reducer:   reducer,
+        reduceCh:  handleToChan(logger, reducer),
+        toolOutCh: make(chan toolOut),
+        cleanCh:   make(chan string),
+    }
+
+    go syncer.background()
+
+    return syncer
+}
+
+type toolOut struct {
+    reqID   string
+    toolIDs map[string]struct{}
+    ch      chan openai.ChatCompletionMessage
+}
+
+func (f *callSyncer) Call(ctx context.Context, transID, reqID string, tagToolCalls map[uint32][]*openai.ToolCall) ([]openai.ChatCompletionMessage, error) {
+    defer func() {
+        f.cleanCh <- reqID
+    }()
+
+    toolIDs, err := f.fire(transID, reqID, tagToolCalls)
+    if err != nil {
+        return nil, err
+    }
+    ch := make(chan openai.ChatCompletionMessage)
+
+    otherToolIDs := make(map[string]struct{})
+    for id := range toolIDs {
+        otherToolIDs[id] = struct{}{}
+    }
+
+    toolOut := toolOut{
+        reqID:   reqID,
+        toolIDs: otherToolIDs,
+        ch:      ch,
+    }
+
+    f.toolOutCh <- toolOut
+
+    var result []openai.ChatCompletionMessage
+    for {
+        select {
+        case <-ctx.Done():
+            return nil, ctx.Err()
+        case res := <-ch:
+            result = append(result, res)
+
+            delete(toolIDs, res.ToolCallID)
+            if len(toolIDs) == 0 {
+                return result, nil
+            }
+        case <-time.After(f.timeout):
+            for id := range toolIDs {
+                result = append(result, openai.ChatCompletionMessage{
+                    ToolCallID: id,
+                    Role:       openai.ChatMessageRoleTool,
+                    Content:    "timeout in this function calling, you should ignore this.",
+                })
+            }
+            return result, nil
+        }
+    }
+}
+
+func (f *callSyncer) fire(transID string, reqID string, tagToolCalls map[uint32][]*openai.ToolCall) (map[string]struct{}, error) {
+    toolIDs := make(map[string]struct{})
+
+    for tag, tools := range tagToolCalls {
+        f.logger.Debug("fire tool_calls", "tag", tag, "len(tools)", len(tools), "transID", transID, "reqID", reqID)
+
+        for _, t := range tools {
+            data := &ai.FunctionCall{
+                TransID:      transID,
+                ReqID:        reqID,
+                ToolCallID:   t.ID,
+                FunctionName: t.Function.Name,
+                Arguments:    t.Function.Arguments,
+            }
+            buf, _ := data.Bytes()
+
+            if err := f.source.Write(tag, buf); err != nil {
+                // TODO: maybe we should make a send failed collection here.
+                f.logger.Error("send data to zipper", "err", err.Error())
+                continue
+            }
+            toolIDs[t.ID] = struct{}{}
+        }
+    }
+
+    return toolIDs, nil
+}
+
+// Close closes the CallSyncer. After closing, the CallSyncer can no longer be used.
+func (f *callSyncer) Close() error {
+    f.cancel()
+
+    var err error
+    if err = f.source.Close(); err != nil {
+        f.logger.Error("callSyncer writer close", "err", err.Error())
+    }
+
+    if err = f.reducer.Close(); err != nil {
+        f.logger.Error("callSyncer reducer close", "err", err.Error())
+    }
+
+    return err
+}
+
+func (f *callSyncer) background() {
+    // buffered stores the messages from the reducer, the key is the reqID
+    buffered := make(map[string]map[string]openai.ChatCompletionMessage)
+    // signals stores the result channels, keyed by reqID; buffered messages are flushed to a request's channel once its signal arrives.
+    signals := make(map[string]toolOut)
+
+    for {
+        select {
+        case <-f.ctx.Done():
+            return
+
+        case out := <-f.toolOutCh:
+            signals[out.reqID] = out
+
+            // send buffered data to the result channel, one ToolCallID has one result.
+            for _, msg := range buffered[out.reqID] {
+                if _, ok := out.toolIDs[msg.ToolCallID]; !ok {
+                    continue
+                }
+                out.ch <- msg
+                delete(buffered[out.reqID], msg.ToolCallID)
+                delete(signals[out.reqID].toolIDs, msg.ToolCallID)
+            }
+
+        case reqID := <-f.cleanCh:
+            delete(buffered, reqID)
+            delete(signals, reqID)
+
+        case msg := <-f.reduceCh:
+            if msg.reqID == "" {
+                f.logger.Warn("recv unexpected message", "msg", msg)
+                continue
+            }
+            result := openai.ChatCompletionMessage{
+                ToolCallID: msg.message.ToolCallID,
+                Role:       msg.message.Role,
+                Content:    msg.message.Content,
+            }
+
+            sig, ok := signals[msg.reqID]
+            // the signal that requests the result has not been sent yet, so buffer the data from the reducer.
+            if !ok {
+                _, ok := buffered[msg.reqID]
+                if !ok {
+                    buffered[msg.reqID] = make(map[string]openai.ChatCompletionMessage)
+                }
+                buffered[msg.reqID][msg.message.ToolCallID] = result
+            } else {
+                // the signal was sent;
+                // check if the message has been sent, and if not, send it to the signal's channel.
+                if _, ok := sig.toolIDs[msg.message.ToolCallID]; ok {
+                    sig.ch <- result
+                }
+            }
+        }
+    }
+}
+
+func handleToChan(logger *slog.Logger, reducer yomo.StreamFunction) <-chan reduceMessage {
+    ch := make(chan reduceMessage)
+
+    reducer.SetHandler(func(ctx serverless.Context) {
+        invoke, err := ctx.LLMFunctionCall()
+        if err != nil {
+            ch <- reduceMessage{reqID: ""}
+            logger.Error("parse function calling invoke", "err", err.Error())
+            return
+        }
+        logger.Debug("sfn-reducer", "req_id", invoke.ReqID, "tool_call_id", invoke.ToolCallID, "result", string(invoke.Result))
+
+        message := openai.ChatCompletionMessage{
+            Role:       openai.ChatMessageRoleTool,
+            Content:    invoke.Result,
+            ToolCallID: invoke.ToolCallID,
+        }
+
+        ch <- reduceMessage{reqID: invoke.ReqID, message: message}
+    })
+
+    return ch
+}
diff --git a/pkg/bridge/ai/call_syncer_test.go b/pkg/bridge/ai/call_syncer_test.go
new file mode 100644
index 000000000..a06ea9fc4
--- /dev/null
+++ b/pkg/bridge/ai/call_syncer_test.go
@@ -0,0 +1,160 @@
+package ai
+
+import (
+    "context"
+    "log/slog"
+    "sync"
+    "testing"
+    "time"
+
+    openai "github.com/sashabaranov/go-openai"
+    "github.com/stretchr/testify/assert"
+    "github.com/yomorun/yomo"
+    "github.com/yomorun/yomo/core"
+    "github.com/yomorun/yomo/serverless"
+    "github.com/yomorun/yomo/serverless/mock"
+)
+
+var testdata = map[uint32][]*openai.ToolCall{
+    1: {{ID: "tool-call-id-1", Function: openai.FunctionCall{Name: "function-1"}}},
+    2: {{ID: "tool-call-id-2", Function: openai.FunctionCall{Name: "function-2"}}},
+    3: {{ID: "tool-call-id-3", Function: openai.FunctionCall{Name: "function-3"}}},
+    4: {{ID: "tool-call-id-4", Function: openai.FunctionCall{Name: "function-4"}}},
+}
+
+func TestTimeoutCallSyncer(t *testing.T) {
+    h := newHandler(2 * time.Hour) // h.sleep > syncer.timeout
+    flow := newMockDataFlow(h.handle)
+    defer flow.Close()
+
+    syncer := NewCallSyncer(slog.Default(), flow, flow, time.Millisecond)
+    defer syncer.Close()
+    go flow.run()
+
+    var (
+        transID = "mock-trans-id"
+        reqID   = "mock-req-id"
+    )
+
+    want := []openai.ChatCompletionMessage{
+        {
+            Role:       openai.ChatMessageRoleTool,
+            ToolCallID: "tool-call-id",
+            Content:    "timeout in this function calling, you should ignore this.",
+        },
+    }
+
+    got, _ := syncer.Call(context.TODO(), transID, reqID, map[uint32][]*openai.ToolCall{
+        1: {{ID: "tool-call-id", Function: openai.FunctionCall{Name: "timeout-function"}}},
+    })
+
+    assert.ElementsMatch(t, want, got)
+}
+
+func TestCallSyncer(t *testing.T) {
+    h := newHandler(0)
+    flow := newMockDataFlow(h.handle)
+    defer flow.Close()
+
+    syncer := NewCallSyncer(slog.Default(), flow, flow, 0)
+    defer syncer.Close()
+    go flow.run()
+
+    var (
+        transID = "mock-trans-id"
+        reqID   = "mock-req-id"
+    )
+
+    got, _ := syncer.Call(context.TODO(), transID, reqID, testdata)
+
+    assert.NotEmpty(t, got)
+    assert.ElementsMatch(t, h.result(), got)
+}
+
+// handler.handle implements core.AsyncHandler; it just echoes the contexts that were written.
+type handler struct {
+    sleep time.Duration
+    mu    sync.Mutex
+    ctxs  map[*mock.MockContext]struct{}
+}
+
+func newHandler(sleep time.Duration) *handler {
+    return &handler{
+        sleep: sleep,
+        ctxs:  make(map[*mock.MockContext]struct{}),
+    }
+}
+
+func (h *handler) handle(c serverless.Context) {
+    time.Sleep(h.sleep)
+
+    h.mu.Lock()
+    defer h.mu.Unlock()
+    h.ctxs[c.(*mock.MockContext)] = struct{}{}
+}
+
+func (h *handler) result() []openai.ChatCompletionMessage {
+    h.mu.Lock()
+    defer h.mu.Unlock()
+
+    want := []openai.ChatCompletionMessage{}
+    for c := range h.ctxs {
+        invoke, _ := c.LLMFunctionCall()
+        want = append(want, openai.ChatCompletionMessage{
+            Role: openai.ChatMessageRoleTool, Content: invoke.Result, ToolCallID: invoke.ToolCallID,
+        })
+    }
+
+    return want
+}
+
+// mockDataFlow mocks the data flow of the ai bridge.
+// The data flow is: source -> handler -> reducer,
+// i.e. `Write() -> handler() -> reducer()` in this mock implementation.
+type mockDataFlow struct {
+    wrCh    chan *mock.MockContext
+    reducer core.AsyncHandler
+    handler core.AsyncHandler
+}
+
+func newMockDataFlow(handler core.AsyncHandler) *mockDataFlow {
+    return &mockDataFlow{
+        wrCh:    make(chan *mock.MockContext),
+        handler: handler,
+    }
+}
+
+func (t *mockDataFlow) Write(tag uint32, data []byte) error {
+    t.wrCh <- mock.NewMockContext(data, tag)
+    return nil
+}
+
+func (t *mockDataFlow) SetHandler(fn core.AsyncHandler) error {
+    t.reducer = fn
+    return nil
+}
+
+func (t *mockDataFlow) Close() error { return nil }
+
+// this function explains how the data flow works:
+// it receives data from the write channel, handles it with the handler, then sends the result to the reducer.
+func (t *mockDataFlow) run() {
+    for c := range t.wrCh {
+        t.handler(c)
+        t.reducer(c)
+    }
+}
+
+var _ yomo.Source = (*mockDataFlow)(nil)
+var _ yomo.StreamFunction = (*mockDataFlow)(nil)
+
+// The tests do not use the following functions in this mock implementation.
+func (t *mockDataFlow) Init(fn func() error) error                           { panic("unimplemented") }
+func (t *mockDataFlow) SetCronHandler(spec string, fn core.CronHandler) error { panic("unimplemented") }
+func (t *mockDataFlow) SetObserveDataTags(tag ...uint32)                      { panic("unimplemented") }
+func (t *mockDataFlow) SetPipeHandler(fn core.PipeHandler) error              { panic("unimplemented") }
+func (t *mockDataFlow) SetWantedTarget(string)                                { panic("unimplemented") }
+func (t *mockDataFlow) Wait()                                                 { panic("unimplemented") }
+func (t *mockDataFlow) Connect() error                                        { panic("unimplemented") }
+func (t *mockDataFlow) SetErrorHandler(fn func(err error))                    { panic("unimplemented") }
+func (t *mockDataFlow) WriteWithTarget(_ uint32, _ []byte, _ string) error    { panic("unimplemented") }
diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/caller.go
similarity index 53%
rename from pkg/bridge/ai/service.go
rename to pkg/bridge/ai/caller.go
index 05bd1d4ed..5350d2bb2 100644
--- a/pkg/bridge/ai/service.go
+++ b/pkg/bridge/ai/caller.go
@@ -5,209 +5,143 @@ import (
     "encoding/json"
     "fmt"
     "io"
+    "log/slog"
     "net/http"
     "strings"
-    "sync"
     "sync/atomic"
     "time"
 
     "github.com/hashicorp/golang-lru/v2/expirable"
-    "github.com/sashabaranov/go-openai"
+    openai "github.com/sashabaranov/go-openai"
     "github.com/yomorun/yomo"
     "github.com/yomorun/yomo/ai"
     "github.com/yomorun/yomo/core/metadata"
     "github.com/yomorun/yomo/core/ylog"
+    "github.com/yomorun/yomo/pkg/bridge/ai/provider"
     "github.com/yomorun/yomo/pkg/bridge/ai/register"
     "github.com/yomorun/yomo/pkg/id"
-    "github.com/yomorun/yomo/serverless"
 )
 
 var (
-    // ServiceCacheSize is the size of the service cache
-    ServiceCacheSize = 1024
-    // ServiceCacheTTL is the time to live of the service cache
-    ServiceCacheTTL = time.Minute * 0 // 30
-    // services is the cache of Service
-    services *expirable.LRU[string, *Service]
+    // CallerProviderCacheSize is the size of the caller provider cache
+    CallerProviderCacheSize = 1024
+    // CallerProviderCacheTTL is the time to live of the provider cache
+    CallerProviderCacheTTL = time.Minute * 0
 )
 
-// Service is used to invoke LLM Provider to get the functions to be executed,
-// then, use source to send arguments which returned by llm provider to target
-// function. Finally, use reducer to aggregate all the results, and write the
-// result by the http.ResponseWriter.
-type Service struct {
-    credential   string
-    zipperAddr   string
-    Metadata     metadata.M
-    systemPrompt atomic.Value
-    source       yomo.Source
-    reducer      yomo.StreamFunction
-    sfnCallCache map[string]*sfnAsyncCall
-    muCallCache  sync.Mutex
-    LLMProvider
-}
+// CallerProvider provides the caller, which is used to interact with YoMo's stream function.
+type CallerProvider struct {
+    lp          provider.LLMProvider
+    zipperAddr  string
+    exFn        ExchangeMetadataFunc
+    provideFunc provideFunc
+    callers     *expirable.LRU[string, *Caller]
+}
 
-// LoadOrCreateService loads or creates a new AI service, if the service is already created, it will return the existing one
-func LoadOrCreateService(credential string, zipperAddr string, aiProvider LLMProvider, exFn ExchangeMetadataFunc) (*Service, error) {
-    s, ok := services.Get(credential)
-    if ok {
-        return s, nil
-    }
-    s, err := newService(credential, zipperAddr, aiProvider, exFn)
-    if err != nil {
-        return nil, err
-    }
-    services.Add(credential, s)
-    return s, nil
-}
-
-// ExchangeMetadataFunc is used to exchange metadata
-type ExchangeMetadataFunc func(credential string) (metadata.M, error)
+type provideFunc func(string, string, provider.LLMProvider, ExchangeMetadataFunc) (*Caller, error)
 
-// DefaultExchangeMetadataFunc is the default ExchangeMetadataFunc, It returns an empty metadata.
-func DefaultExchangeMetadataFunc(credential string) (metadata.M, error) {
-    return metadata.M{}, nil
+// NewCallerProvider returns a new caller provider.
+func NewCallerProvider(zipperAddr string, lp provider.LLMProvider, exFn ExchangeMetadataFunc) *CallerProvider {
+    return newCallerProvider(zipperAddr, lp, exFn, NewCaller)
 }
 
-func newService(credential string, zipperAddr string, aiProvider LLMProvider, exFn ExchangeMetadataFunc) (*Service, error) {
-    s := &Service{
-        credential:   credential,
-        zipperAddr:   zipperAddr,
-        LLMProvider:  aiProvider,
-        sfnCallCache: make(map[string]*sfnAsyncCall),
+func newCallerProvider(zipperAddr string, lp provider.LLMProvider, exFn ExchangeMetadataFunc, provideFunc provideFunc) *CallerProvider {
+    p := &CallerProvider{
+        zipperAddr:  zipperAddr,
+        lp:          lp,
+        exFn:        exFn,
+        provideFunc: provideFunc,
+        callers:     expirable.NewLRU(CallerProviderCacheSize, func(_ string, caller *Caller) { caller.Close() }, CallerProviderCacheTTL),
     }
-    s.SetSystemPrompt("")
+    return p
+}
 
-    // metadata
-    if exFn == nil {
-        s.Metadata = metadata.M{}
-    } else {
-        md, err := exFn(credential)
-        if err != nil {
-            ylog.Error("exchange metadata failed", "err", err)
-            return nil, err
-        }
-        s.Metadata = md
+// Provide provides the caller according to the credential.
+func (p *CallerProvider) Provide(credential string) (*Caller, error) {
+    caller, ok := p.callers.Get(credential)
+    if ok {
+        return caller, nil
     }
 
-    // source
-    source, err := s.createSource()
+    caller, err := p.provideFunc(credential, p.zipperAddr, p.lp, p.exFn)
     if err != nil {
-        ylog.Error("create fc-service source failed", "err", err)
         return nil, err
     }
-    s.source = source
+    p.callers.Add(credential, caller)
 
-    // reducer
-    reducer, err := s.createReducer()
-    if err != nil {
-        ylog.Error("create fc-service reducer failed", "err", err)
-        return nil, err
-    }
-    s.reducer = reducer
-    return s, nil
-}
+    return caller, nil
+}
 
-// SetSystemPrompt sets the system prompt
-func (s *Service) SetSystemPrompt(prompt string) {
-    s.systemPrompt.Store(prompt)
-}
-
-// Release releases the resources
-func (s *Service) Release() {
-    ylog.Debug("release AI service", "credential", s.credential)
-    if s.source != nil {
-        s.source.Close()
-    }
-    if s.reducer != nil {
-        s.reducer.Close()
-    }
+// Caller calls the invoke function and the chat completion function.
+type Caller struct {
+    CallSyncer
+
+    credential   string
+    md           metadata.M
+    systemPrompt atomic.Value
+    provider     provider.LLMProvider
 }
 
-func (s *Service) createSource() (yomo.Source, error) {
-    ylog.Debug("create fc-service source", "zipperAddr", s.zipperAddr, "credential", s.credential)
+// NewCaller returns a new caller.
+func NewCaller(credential string, zipperAddr string, provider provider.LLMProvider, exFn ExchangeMetadataFunc) (*Caller, error) {
     source := yomo.NewSource(
         "fc-source",
-        s.zipperAddr,
+        zipperAddr,
         yomo.WithSourceReConnect(),
-        yomo.WithCredential(s.credential),
+        yomo.WithCredential(credential),
     )
-    // create ai source
     err := source.Connect()
     if err != nil {
         return nil, err
     }
-    return source, nil
-}
 
-// createReducer creates the reducer-sfn. reducer-sfn used to aggregate all the llm-sfn execute results.
-func (s *Service) createReducer() (yomo.StreamFunction, error) {
-    sfn := yomo.NewStreamFunction(
+    reducer := yomo.NewStreamFunction(
         "ai-reducer",
-        s.zipperAddr,
+        zipperAddr,
         yomo.WithSfnReConnect(),
-        yomo.WithSfnCredential(s.credential),
+        yomo.WithSfnCredential(credential),
     )
-    sfn.SetObserveDataTags(ai.ReducerTag)
-    sfn.SetHandler(func(ctx serverless.Context) {
-        buf := ctx.Data()
-        ylog.Debug("[sfn-reducer]", "tag", ai.ReducerTag, "data", string(buf))
-        invoke, err := ctx.LLMFunctionCall()
-        if err != nil {
-            ylog.Error("[sfn-reducer] parse function calling invoke", "err", err.Error())
-            return
-        }
-
-        reqID := invoke.ReqID
-
-        // write parallel function calling results to cache, after all the results are written, the reducer will be done
-        s.muCallCache.Lock()
-        c, ok := s.sfnCallCache[reqID]
-        s.muCallCache.Unlock()
-        if !ok {
-            ylog.Error("[sfn-reducer] req_id not found", "trans_id", invoke.TransID, "req_id", reqID)
-            return
-        }
-
-        c.mu.Lock()
-        defer c.mu.Unlock()
-
-        // need lock c.val as multiple handler channel will write to it
-        c.val[invoke.ToolCallID] = ai.ToolMessage{
-            Content:    invoke.Result,
-            ToolCallId: invoke.ToolCallID,
-        }
-        ylog.Debug("[sfn-reducer] generate", "ToolMessage", fmt.Sprintf("%+v", c.val))
+    reducer.SetObserveDataTags(ai.ReducerTag)
 
-        c.wg.Done()
-    })
+    // this line must come before `Connect()`, because the handler must be set before connecting.
+    callSyncer := NewCallSyncer(slog.Default(), source, reducer, 60*time.Second)
 
-    err := sfn.Connect()
-    if err != nil {
+    if err := reducer.Connect(); err != nil {
         return nil, err
     }
-    return sfn, nil
-}
 
-// GetOverview returns the overview of the AI functions, key is the tag, value is the function definition
-func (s *Service) GetOverview() (*ai.OverviewResponse, error) {
-    tcs, err := register.ListToolCalls(s.Metadata)
+    md, err := exFn(credential)
     if err != nil {
-        return &ai.OverviewResponse{}, err
+        return nil, err
     }
 
-    functions := make(map[uint32]*openai.FunctionDefinition)
-    for tag, tc := range tcs {
-        functions[tag] = tc.Function
+    caller := &Caller{
+        CallSyncer: callSyncer,
+        credential: credential,
+        md:         md,
+        provider:   provider,
     }
 
-    return &ai.OverviewResponse{Functions: functions}, nil
+    caller.SetSystemPrompt("")
+
+    return caller, nil
+}
+
+// SetSystemPrompt sets the system prompt
+func (c *Caller) SetSystemPrompt(prompt string) {
+    c.systemPrompt.Store(prompt)
+}
+
+// Metadata returns the metadata of the caller.
+func (c *Caller) Metadata() metadata.M { + return c.md } // GetInvoke returns the invoke response -func (s *Service) GetInvoke(ctx context.Context, userInstruction string, baseSystemMessage string, transID string, includeCallStack bool) (*ai.InvokeResponse, error) { +func (c *Caller) GetInvoke(ctx context.Context, userInstruction string, baseSystemMessage string, transID string, includeCallStack bool) (*ai.InvokeResponse, error) { // read tools attached to the metadata - tcs, err := register.ListToolCalls(s.Metadata) + tcs, err := register.ListToolCalls(c.md) if err != nil { return &ai.InvokeResponse{}, err } @@ -227,7 +161,7 @@ func (s *Service) GetInvoke(ctx context.Context, userInstruction string, baseSys promptUsage int completionUsage int ) - chatCompletionResponse, err := s.LLMProvider.GetChatCompletions(ctx, req, s.Metadata) + chatCompletionResponse, err := c.provider.GetChatCompletions(ctx, req, c.md) if err != nil { return nil, err } @@ -241,7 +175,7 @@ func (s *Service) GetInvoke(ctx context.Context, userInstruction string, baseSys return nil, err } // if no tool_calls fired, just return the llm text result - if !(res.FinishReason == "tool_calls" || res.FinishReason == "gemini_tool_calls") { + if res.FinishReason != string(openai.FinishReasonToolCalls) { return res, nil } @@ -251,20 +185,23 @@ func (s *Service) GetInvoke(ctx context.Context, userInstruction string, baseSys "res_assistant_msgs", fmt.Sprintf("%+v", res.AssistantMessage)) ylog.Debug(">> run function calls", "transID", transID, "res.ToolCalls", fmt.Sprintf("%+v", res.ToolCalls)) - llmCalls, err := s.runFunctionCalls(res.ToolCalls, transID, id.New(16)) + + reqID := id.New(16) + llmCalls, err := c.Call(ctx, transID, reqID, res.ToolCalls) if err != nil { return nil, err } ylog.Debug(">>>> start 2nd call with", "calls", fmt.Sprintf("%+v", llmCalls), "preceeding_assistant_message", fmt.Sprintf("%+v", res.AssistantMessage)) + chainMessage.PreceedingAssistantMessage = res.AssistantMessage - chainMessage.ToolMessages = llmCalls + chainMessage.ToolMessages = transToolMessage(llmCalls) // do not attach toolMessage to prompt in 2nd call messages2 := prepareMessages(baseSystemMessage, userInstruction, chainMessage, tools, false) req2 := openai.ChatCompletionRequest{ Messages: messages2, } - chatCompletionResponse2, err := s.LLMProvider.GetChatCompletions(ctx, req2, s.Metadata) + chatCompletionResponse2, err := c.provider.GetChatCompletions(ctx, req2, c.md) if err != nil { return nil, err } @@ -280,59 +217,17 @@ func (s *Service) GetInvoke(ctx context.Context, userInstruction string, baseSys // INFO: call stack infomation if includeCallStack { res2.ToolCalls = res.ToolCalls - res2.ToolMessages = llmCalls + res2.ToolMessages = transToolMessage(llmCalls) } ylog.Debug("<<<< complete 2nd call", "res2", fmt.Sprintf("%+v", res2)) return res2, err } -func addToolsToRequest(req openai.ChatCompletionRequest, tagTools map[uint32]openai.Tool) openai.ChatCompletionRequest { - toolCalls := prepareToolCalls(tagTools) - - if len(toolCalls) > 0 { - req.Tools = toolCalls - } - - ylog.Debug(" #1 first call", "request", fmt.Sprintf("%+v", req)) - - return req -} - -func overWriteSystemPrompt(req openai.ChatCompletionRequest, sysPrompt string) openai.ChatCompletionRequest { - // do nothing if system prompt is empty - if sysPrompt == "" { - return req - } - // over write system prompt - isOverWrite := false - for i, msg := range req.Messages { - if msg.Role != "system" { - continue - } - req.Messages[i] = openai.ChatCompletionMessage{ - Role: 
msg.Role, - Content: sysPrompt, - } - isOverWrite = true - } - // append system prompt - if !isOverWrite { - req.Messages = append(req.Messages, openai.ChatCompletionMessage{ - Role: "system", - Content: sysPrompt, - }) - } - - ylog.Debug(" #1 first call after overwrite", "request", fmt.Sprintf("%+v", req)) - - return req -} - -// GetChatCompletions returns the llm api response -func (s *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, transID string, w http.ResponseWriter, includeCallStack bool) error { +// GetChatCompletions accepts openai.ChatCompletionRequest and responds to http.ResponseWriter. +func (c *Caller) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, transID string, w http.ResponseWriter) error { // 1. find all hosting tool sfn - tagTools, err := register.ListToolCalls(s.Metadata) + tagTools, err := register.ListToolCalls(c.md) if err != nil { return err } @@ -340,7 +235,7 @@ req = addToolsToRequest(req, tagTools) // 3. over write system prompt to request - req = overWriteSystemPrompt(req, s.systemPrompt.Load().(string)) + req = overWriteSystemPrompt(req, c.systemPrompt.Load().(string)) var ( promptUsage = 0 @@ -357,7 +252,7 @@ flusher = eventFlusher(w) isFunctionCall = false ) - resStream, err := s.LLMProvider.GetChatCompletionsStream(ctx, req, s.Metadata) + resStream, err := c.provider.GetChatCompletionsStream(ctx, req, c.md) if err != nil { return err } @@ -384,7 +279,7 @@ item, ok := toolCallsMap[index] if !ok { toolCallsMap[index] = openai.ToolCall{ - Index: &index, + Index: t.Index, ID: t.ID, Type: t.Type, Function: openai.FunctionCall{}, } @@ -401,33 +296,27 @@ } isFunctionCall = true } else if streamRes.Choices[0].FinishReason != openai.FinishReasonToolCalls { - _, _ = io.WriteString(w, "data: ") - _ = json.NewEncoder(w).Encode(streamRes) - _, _ = io.WriteString(w, "\n") - flusher.Flush() + _ = writeStreamEvent(w, flusher, streamRes) } } if !isFunctionCall { - io.WriteString(w, "data: [DONE]") - flusher.Flush() - return nil - } else { - toolCalls = mapToSliceTools(toolCallsMap) + return writeStreamDone(w, flusher) + } + toolCalls = mapToSliceTools(toolCallsMap) - assistantMessage = openai.ChatCompletionMessage{ - ToolCalls: toolCalls, - Role: openai.ChatMessageRoleAssistant, - } - flusher.Flush() + assistantMessage = openai.ChatCompletionMessage{ + ToolCalls: toolCalls, + Role: openai.ChatMessageRoleAssistant, } + flusher.Flush() } else { - resp, err := s.LLMProvider.GetChatCompletions(ctx, req, s.Metadata) + resp, err := c.provider.GetChatCompletions(ctx, req, c.md) if err != nil { return err } promptUsage = resp.Usage.PromptTokens completionUsage = resp.Usage.CompletionTokens totalUsage = resp.Usage.TotalTokens ylog.Debug(" #1 first call", "response", fmt.Sprintf("%+v", resp)) // it is a function call @@ -442,49 +331,32 @@ } // 5. 
find sfns that hit the function call - fnCalls := make(map[uint32][]*openai.ToolCall) - // functions may be more than one - for _, call := range toolCalls { - for tag, tc := range tagTools { - if tc.Function.Name == call.Function.Name && tc.Type == call.Type { - currentCall := call - fnCalls[tag] = append(fnCalls[tag], &currentCall) - } - } - } + fnCalls := findTagTools(tagTools, toolCalls) + // 6. run llm function calls reqID := id.New(16) - llmCalls, err := s.runFunctionCalls(fnCalls, transID, reqID) + llmCalls, err := c.Call(ctx, transID, reqID, fnCalls) if err != nil { return err } + // 7. do the second call (the second call messages are from user input, first call response and sfn call results) req.Messages = append(reqMessages, assistantMessage) - for _, tool := range llmCalls { - tm := openai.ChatCompletionMessage{ - Role: "tool", - Content: tool.Content, - ToolCallID: tool.ToolCallId, - } - req.Messages = append(req.Messages, tm) - } - // reset tools field - req.Tools = nil + req.Messages = append(req.Messages, llmCalls...) + req.Tools = nil // reset tools field ylog.Debug(" #2 second call", "request", fmt.Sprintf("%+v", req)) if req.Stream { flusher := w.(http.Flusher) - resStream, err := s.LLMProvider.GetChatCompletionsStream(ctx, req, s.Metadata) + resStream, err := c.provider.GetChatCompletionsStream(ctx, req, c.md) if err != nil { return err } for { streamRes, err := resStream.Recv() if err == io.EOF { - io.WriteString(w, "data: [DONE]") - flusher.Flush() - return nil + return writeStreamDone(w, flusher) } if err != nil { return err } @@ -494,13 +366,10 @@ streamRes.Usage.CompletionTokens += completionUsage streamRes.Usage.TotalTokens += totalUsage } - _, _ = io.WriteString(w, "data: ") - _ = json.NewEncoder(w).Encode(streamRes) - _, _ = io.WriteString(w, "\n") - flusher.Flush() + _ = writeStreamEvent(w, flusher, streamRes) } } else { - resp, err := s.LLMProvider.GetChatCompletions(ctx, req, s.Metadata) + resp, err := c.provider.GetChatCompletions(ctx, req, c.md) if err != nil { return err } @@ -514,122 +383,90 @@ } } -// run llm-sfn function calls -func (s *Service) runFunctionCalls(fns map[uint32][]*openai.ToolCall, transID, reqID string) ([]ai.ToolMessage, error) { - if len(fns) == 0 { - return nil, nil - } +// ExchangeMetadataFunc is used to exchange metadata +type ExchangeMetadataFunc func(credential string) (metadata.M, error) - asyncCall := &sfnAsyncCall{ - val: make(map[string]ai.ToolMessage), - } +// DefaultExchangeMetadataFunc is the default ExchangeMetadataFunc; it returns empty metadata. 
+func DefaultExchangeMetadataFunc(credential string) (metadata.M, error) { + return metadata.M{}, nil +} - s.muCallCache.Lock() - s.sfnCallCache[reqID] = asyncCall - s.muCallCache.Unlock() +func addToolsToRequest(req openai.ChatCompletionRequest, tagTools map[uint32]openai.Tool) openai.ChatCompletionRequest { + toolCalls := prepareToolCalls(tagTools) - firedToolIDs := make(map[string]struct{}) - for tag, tcs := range fns { - ylog.Debug("+++invoke toolCalls", "tag", tag, "len(toolCalls)", len(tcs), "transID", transID, "reqID", reqID) - for _, fn := range tcs { - err := s.fireLlmSfn(tag, fn, transID, reqID) - if err != nil { - ylog.Error("send data to zipper", "err", err.Error()) - continue - } - // record toolIDs be fired - firedToolIDs[fn.ID] = struct{}{} - // wait for this request to be done - asyncCall.wg.Add(1 * register.SfnFactor(tag, s.Metadata)) - } + if len(toolCalls) > 0 { + req.Tools = toolCalls } - // wait for reducer to finish, the aggregation results - done := make(chan struct{}, 1) - go func() { - asyncCall.wg.Wait() - close(done) - }() - - select { - case <-time.After(RunFunctionTimeout): - case <-done: - } + ylog.Debug(" #1 first call", "request", fmt.Sprintf("%+v", req)) - arr := make([]ai.ToolMessage, 0) + return req +} - asyncCall.mu.RLock() - for toolID := range firedToolIDs { - call, ok := asyncCall.val[toolID] - // `ok==false` represents that the toolID be fired but not receive a response. - if !ok { - arr = append(arr, ai.ToolMessage{ - Role: "tool", - ToolCallId: toolID, - Content: "error occured in this function calling, you should ignore this.", - }) +func overWriteSystemPrompt(req openai.ChatCompletionRequest, sysPrompt string) openai.ChatCompletionRequest { + // do nothing if system prompt is empty + if sysPrompt == "" { + return req + } + // overwrite the system prompt + isOverWrite := false + for i, msg := range req.Messages { + if msg.Role != "system" { continue } - call.Role = "tool" - arr = append(arr, call) + req.Messages[i] = openai.ChatCompletionMessage{ + Role: msg.Role, + Content: sysPrompt, + } + isOverWrite = true + } + // append system prompt + if !isOverWrite { + req.Messages = append(req.Messages, openai.ChatCompletionMessage{ + Role: "system", + Content: sysPrompt, + }) } - asyncCall.mu.RUnlock() - return arr, nil -} + ylog.Debug(" #1 first call after overwrite", "request", fmt.Sprintf("%+v", req)) -// fireLlmSfn fires the llm-sfn function call by s.source.Write() -func (s *Service) fireLlmSfn(tag uint32, fn *openai.ToolCall, transID, reqID string) error { - ylog.Info( - "+invoke func", - "tag", tag, - "transID", transID, - "reqID", reqID, - "toolCallID", fn.ID, - "function", fn.Function.Name, - "arguments", fn.Function.Arguments, - ) - data := &ai.FunctionCall{ - TransID: transID, - ReqID: reqID, - ToolCallID: fn.ID, - FunctionName: fn.Function.Name, - Arguments: fn.Function.Arguments, - } - buf, err := data.Bytes() - if err != nil { - ylog.Error("marshal data", "err", err.Error()) - } - return s.source.Write(tag, buf) + return req } -// Write writes the data to zipper -func (s *Service) Write(tag uint32, data []byte) error { - return s.source.Write(tag, data) +func findTagTools(tagTools map[uint32]openai.Tool, toolCalls []openai.ToolCall) map[uint32][]*openai.ToolCall { + fnCalls := make(map[uint32][]*openai.ToolCall) + // functions may be more than one + for _, call := range toolCalls { + for tag, tc := range tagTools { + if tc.Function.Name == call.Function.Name && tc.Type == call.Type { + currentCall := call + fnCalls[tag] = append(fnCalls[tag], 
&currentCall) + } + } + } + return fnCalls } -func init() { - onEvicted := func(_ string, v *Service) { - v.Release() +func writeStreamEvent(w http.ResponseWriter, flusher http.Flusher, streamRes openai.ChatCompletionStreamResponse) error { + if _, err := io.WriteString(w, "data: "); err != nil { + return err } - services = expirable.NewLRU(ServiceCacheSize, onEvicted, ServiceCacheTTL) -} + if err := json.NewEncoder(w).Encode(streamRes); err != nil { + return err + } + if _, err := io.WriteString(w, "\n"); err != nil { + return err + } + flusher.Flush() -type sfnAsyncCall struct { - wg sync.WaitGroup - mu sync.RWMutex - val map[string]ai.ToolMessage + return nil } -func prepareToolCalls(tcs map[uint32]openai.Tool) []openai.Tool { - // prepare tools - toolCalls := make([]openai.Tool, len(tcs)) - idx := 0 - for _, tc := range tcs { - toolCalls[idx] = tc - idx++ - } - return toolCalls +func writeStreamDone(w http.ResponseWriter, flusher http.Flusher) error { + _, err := io.WriteString(w, "data: [DONE]") + flusher.Flush() + + return err } func prepareMessages(baseSystemMessage string, userInstruction string, chainMessage ai.ChainMessage, tools []openai.Tool, withTool bool) []openai.ChatCompletionMessage { @@ -671,7 +508,7 @@ tm := openai.ChatCompletionMessage{ Role: "tool", Content: tool.Content, - ToolCallID: tool.ToolCallId, + ToolCallID: tool.ToolCallID, } ylog.Debug("======== add toolMessage", "tm", fmt.Sprintf("%+v", tm)) messages = append(messages, tm) @@ -700,3 +537,26 @@ func eventFlusher(w http.ResponseWriter) http.Flusher { flusher := w.(http.Flusher) return flusher } + +func prepareToolCalls(tcs map[uint32]openai.Tool) []openai.Tool { + // prepare tools + toolCalls := make([]openai.Tool, len(tcs)) + idx := 0 + for _, tc := range tcs { + toolCalls[idx] = tc + idx++ + } + return toolCalls +} + +func transToolMessage(msgs []openai.ChatCompletionMessage) []ai.ToolMessage { + toolMessages := make([]ai.ToolMessage, len(msgs)) + for i, msg := range msgs { + toolMessages[i] = ai.ToolMessage{ + Role: msg.Role, + Content: msg.Content, + ToolCallID: msg.ToolCallID, + } + } + return toolMessages +} diff --git a/pkg/bridge/ai/caller_test.go b/pkg/bridge/ai/caller_test.go new file mode 100644 index 000000000..e19af55dc --- /dev/null +++ b/pkg/bridge/ai/caller_test.go @@ -0,0 +1,500 @@ +package ai + +import ( + "context" + "errors" + "net/http/httptest" + "testing" + + "github.com/hashicorp/golang-lru/v2/expirable" + openai "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" + "github.com/yomorun/yomo/ai" + "github.com/yomorun/yomo/core/metadata" + "github.com/yomorun/yomo/pkg/bridge/ai/provider" + "github.com/yomorun/yomo/pkg/bridge/ai/register" +) + +func TestCallerInvoke(t *testing.T) { + type args struct { + providerMockData []provider.MockData + mockCallReqResp map[uint32][]mockFunctionCall + systemPrompt string + userInstruction string + baseSystemMessage string + } + tests := []struct { + name string + args args + wantRequest []openai.ChatCompletionRequest + wantUsage ai.TokenUsage + }{ + { + name: "invoke with tool call", + args: args{ + providerMockData: []provider.MockData{ + provider.MockChatCompletionResponse(toolCallResp, stopResp), + }, + mockCallReqResp: map[uint32][]mockFunctionCall{ + // toolID should equal toolCallResp's toolID + 0x33: {{toolID: "call_abc123", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, + }, + systemPrompt: "this is a system prompt", 
userInstruction: "hi", + baseSystemMessage: "this is a base system message", + }, + wantRequest: []openai.ChatCompletionRequest{ + { + Messages: []openai.ChatCompletionMessage{ + {Role: "system", Content: "this is a base system message\n\n## Instructions\n- \n\n"}, + {Role: "user", Content: "hi"}, + }, + Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, + }, + { + Messages: []openai.ChatCompletionMessage{ + {Role: "system", Content: "this is a base system message\n\n## Instructions\n"}, + {Role: "assistant", ToolCalls: []openai.ToolCall{{ID: "call_abc123", Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get_current_weather", Arguments: "{\n\"location\": \"Boston, MA\"\n}"}}}}, + {Role: "tool", Content: "temperature: 31°C", ToolCallID: "call_abc123"}, + {Role: "user", Content: "hi"}, + }, + }, + }, + wantUsage: ai.TokenUsage{PromptTokens: 95, CompletionTokens: 43}, + }, + { + name: "invoke without tool call", + args: args{ + providerMockData: []provider.MockData{ + provider.MockChatCompletionResponse(stopResp), + }, + mockCallReqResp: map[uint32][]mockFunctionCall{}, + systemPrompt: "this is a system prompt", + userInstruction: "hi", + baseSystemMessage: "this is a base system message", + }, + wantRequest: []openai.ChatCompletionRequest{ + { + Messages: []openai.ChatCompletionMessage{ + {Role: "system", Content: "this is a base system message\n\n## Instructions\n\n"}, + {Role: "user", Content: "hi"}, + }, + }, + }, + wantUsage: ai.TokenUsage{PromptTokens: 13, CompletionTokens: 26}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + register.SetRegister(register.NewDefault()) + + pd, err := provider.NewMock("mock provider", tt.args.providerMockData...) 
+ if err != nil { + t.Fatal(err) + } + + cp := newMockCallerProvider() + + cp.provideFunc = mockCallerProvideFunc(tt.args.mockCallReqResp, pd) + + caller, err := cp.Provide("") + assert.NoError(t, err) + + caller.SetSystemPrompt(tt.args.systemPrompt) + + resp, err := caller.GetInvoke(context.TODO(), tt.args.userInstruction, tt.args.baseSystemMessage, "transID", true) + assert.NoError(t, err) + + assert.Equal(t, tt.wantUsage, resp.TokenUsage) + assert.Equal(t, tt.wantRequest, pd.RequestRecords()) + }) + } +} + +func TestCallerChatCompletion(t *testing.T) { + type args struct { + providerMockData []provider.MockData + mockCallReqResp map[uint32][]mockFunctionCall + systemPrompt string + request openai.ChatCompletionRequest + } + tests := []struct { + name string + args args + wantRequest []openai.ChatCompletionRequest + }{ + { + name: "chat with tool call", + args: args{ + providerMockData: []provider.MockData{ + provider.MockChatCompletionResponse(toolCallResp, stopResp), + }, + mockCallReqResp: map[uint32][]mockFunctionCall{ + // toolID should equal toolCallResp's toolID + 0x33: {{toolID: "call_abc123", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, + }, + systemPrompt: "this is a system prompt", + request: openai.ChatCompletionRequest{ + Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "How is the weather today in Boston, MA?"}}, + }, + }, + wantRequest: []openai.ChatCompletionRequest{ + { + Messages: []openai.ChatCompletionMessage{ + {Role: "user", Content: "How is the weather today in Boston, MA?"}, + {Role: "system", Content: "this is a system prompt"}, + }, + Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, + }, + { + Messages: []openai.ChatCompletionMessage{ + {Role: "user", Content: "How is the weather today in Boston, MA?"}, + {Role: "system", Content: "this is a system prompt"}, + {Role: "assistant", ToolCalls: []openai.ToolCall{{ID: "call_abc123", Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get_current_weather", Arguments: "{\n\"location\": \"Boston, MA\"\n}"}}}}, + {Role: "tool", Content: "temperature: 31°C", ToolCallID: "call_abc123"}, + }, + }, + }, + }, + { + name: "chat without tool call", + args: args{ + providerMockData: []provider.MockData{ + provider.MockChatCompletionResponse(stopResp), + }, + mockCallReqResp: map[uint32][]mockFunctionCall{ + // toolID should equal toolCallResp's toolID + 0x33: {{toolID: "call_abc123", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, + }, + systemPrompt: "You are an assistant.", + request: openai.ChatCompletionRequest{ + Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "How are you"}}, + }, + }, + wantRequest: []openai.ChatCompletionRequest{ + { + Messages: []openai.ChatCompletionMessage{ + {Role: "user", Content: "How are you"}, + {Role: "system", Content: "You are an assistant."}, + }, + Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, + }, + }, + }, + { + name: "chat with tool call in stream", + args: args{ + providerMockData: []provider.MockData{ + provider.MockChatCompletionStreamResponse(toolCallStreamResp, stopStreamResp), + }, + mockCallReqResp: map[uint32][]mockFunctionCall{ + // toolID should equal toolCallResp's toolID + 0x33: {{toolID: "call_9ctHOJqO3bYrpm2A6S7nHd5k", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, + }, + systemPrompt: "You are 
a weather assistant", + request: openai.ChatCompletionRequest{ + Stream: true, + Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "How is the weather today in Boston, MA?"}}, + }, + }, + wantRequest: []openai.ChatCompletionRequest{ + { + Stream: true, + Messages: []openai.ChatCompletionMessage{ + {Role: "user", Content: "How is the weather today in Boston, MA?"}, + {Role: "system", Content: "You are a weather assistant"}, + }, + Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, + }, + { + Stream: true, + Messages: []openai.ChatCompletionMessage{ + {Role: "user", Content: "How is the weather today in Boston, MA?"}, + {Role: "system", Content: "You are a weather assistant"}, + {Role: "assistant", ToolCalls: []openai.ToolCall{{Index: toInt(0), ID: "call_9ctHOJqO3bYrpm2A6S7nHd5k", Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get_current_weather", Arguments: "{\"location\":\"Boston, MA\"}"}}}}, + {Role: "tool", Content: "temperature: 31°C", ToolCallID: "call_9ctHOJqO3bYrpm2A6S7nHd5k"}, + }, + }, + }, + }, + { + name: "chat without tool call in stream", + args: args{ + providerMockData: []provider.MockData{ + provider.MockChatCompletionStreamResponse(stopStreamResp), + }, + mockCallReqResp: map[uint32][]mockFunctionCall{ + // toolID should equal to toolCallResp's toolID + 0x33: {{toolID: "call_9ctHOJqO3bYrpm2A6S7nHd5k", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, + }, + systemPrompt: "You are a weather assistant", + request: openai.ChatCompletionRequest{ + Stream: true, + Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "How is the weather today in Boston, MA?"}}, + }, + }, + wantRequest: []openai.ChatCompletionRequest{ + { + Stream: true, + Messages: []openai.ChatCompletionMessage{ + {Role: "user", Content: "How is the weather today in Boston, MA?"}, + {Role: "system", Content: "You are a weather assistant"}, + }, + Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + register.SetRegister(register.NewDefault()) + + pd, err := provider.NewMock("mock provider", tt.args.providerMockData...) + if err != nil { + t.Fatal(err) + } + + cp := newMockCallerProvider() + + cp.provideFunc = mockCallerProvideFunc(tt.args.mockCallReqResp, pd) + + caller, err := cp.Provide("") + assert.NoError(t, err) + + caller.SetSystemPrompt(tt.args.systemPrompt) + + w := httptest.NewRecorder() + err = caller.GetChatCompletions(context.TODO(), tt.args.request, "transID", w) + assert.NoError(t, err) + + assert.Equal(t, tt.wantRequest, pd.RequestRecords()) + }) + } +} + +func newMockCallerProvider() *CallerProvider { + cp := &CallerProvider{ + zipperAddr: DefaultZipperAddr, + exFn: DefaultExchangeMetadataFunc, + callers: expirable.NewLRU(CallerProviderCacheSize, func(_ string, caller *Caller) { caller.Close() }, CallerProviderCacheTTL), + } + return cp +} + +// mockCallerProvideFunc returns a mock caller provider, which is used for mockCallerProvider +// the request-response of caller be provided has been defined in advance, the request and response are defined in the `calls`. 
+func mockCallerProvideFunc(calls map[uint32][]mockFunctionCall, p provider.LLMProvider) provideFunc { + // register functions to the register + for tag, call := range calls { + for _, c := range call { + register.RegisterFunction(tag, &openai.FunctionDefinition{Name: c.functionName}, uint64(tag), nil) + } + } + + return func(credential, _ string, provider provider.LLMProvider, _ ExchangeMetadataFunc) (*Caller, error) { + caller := &Caller{ + credential: credential, + provider: p, + md: metadata.M{"hello": "llm bridge"}, + } + + caller.SetSystemPrompt("") + caller.CallSyncer = &mockCallSyncer{calls: calls} + + return caller, nil + } +} + +type mockFunctionCall struct { + toolID string + functionName string + respContent string +} + +type mockCallSyncer struct { + calls map[uint32][]mockFunctionCall +} + +// Call implements CallSyncer; it returns the mock responses defined in advance. +func (m *mockCallSyncer) Call(ctx context.Context, transID string, reqID string, toolCalls map[uint32][]*openai.ToolCall) ([]openai.ChatCompletionMessage, error) { + res := []openai.ChatCompletionMessage{} + for tag, calls := range toolCalls { + mcs, ok := m.calls[tag] + if !ok { + return nil, errors.New("call not found") + } + mcm := make(map[string]mockFunctionCall, len(mcs)) + for _, mc := range mcs { + mcm[mc.toolID] = mc + } + for _, call := range calls { + mc, ok := mcm[call.ID] + if !ok { + return nil, errors.New("call not found") + } + res = append(res, openai.ChatCompletionMessage{ + ToolCallID: mc.toolID, + Role: openai.ChatMessageRoleTool, + Content: mc.respContent, + }) + } + } + return res, nil +} + +func (m *mockCallSyncer) Close() error { return nil } + +func toInt(val int) *int { return &val } + +var stopStreamResp = `data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" I'm"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" just"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" a"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: 
{"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" computer"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" program"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":","},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" so"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" don't"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" have"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" feelings"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":","},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" but"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" I'm"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" here"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" 
and"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" ready"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" help"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" with"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" whatever"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" need"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: 
{"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[],"usage":{"prompt_tokens":13,"completion_tokens":34,"total_tokens":47}} + +data: [DONE]` + +var stopResp = `{ + "id": "chatcmpl-9blYknv9rHvr2dvCQKMeW21hlBpCX", + "object": "chat.completion", + "created": 1718787982, + "model": "gpt-4o-2024-05-13", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! I'm just a computer program, so I don't have feelings, but thanks for asking. How can I assist you today?" 
+ }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 13, + "completion_tokens": 26, + "total_tokens": 39 + }, + "system_fingerprint": "fp_f4e629d0a5" +}` + +var toolCallStreamResp = `data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_9ctHOJqO3bYrpm2A6S7nHd5k","type":"function","function":{"name":"get_current_weather","arguments":""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Boston"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":","}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" MA"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[],"usage":{"prompt_tokens":83,"completion_tokens":17,"total_tokens":100}}` + +var toolCallResp = `{ + "id": "chatcmpl-abc123", + "object": "chat.completion", + 
"created": 1699896916, + "model": "gpt-4-turbo-2024-04-09", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{\n\"location\": \"Boston, MA\"\n}" + } + } + ] + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 82, + "completion_tokens": 17, + "total_tokens": 99 + } +}` diff --git a/pkg/bridge/ai/provider/azopenai/provider.go b/pkg/bridge/ai/provider/azopenai/provider.go index 135bfd980..57417cc38 100644 --- a/pkg/bridge/ai/provider/azopenai/provider.go +++ b/pkg/bridge/ai/provider/azopenai/provider.go @@ -10,8 +10,7 @@ import ( "github.com/sashabaranov/go-openai" "github.com/yomorun/yomo/core/metadata" - "github.com/yomorun/yomo/pkg/bridge/ai" - bridgeai "github.com/yomorun/yomo/pkg/bridge/ai" + "github.com/yomorun/yomo/pkg/bridge/ai/provider" ) // Provider is the provider for Azure OpenAI @@ -23,7 +22,7 @@ type Provider struct { client *openai.Client } -var _ bridgeai.LLMProvider = &Provider{} +var _ provider.LLMProvider = &Provider{} // NewProvider creates a new AzureOpenAIProvider func NewProvider(apiKey string, apiEndpoint string, deploymentID string, apiVersion string) *Provider { @@ -72,6 +71,6 @@ func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatComple } // GetChatCompletionsStream implements ai.LLMProvider. -func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (ai.ResponseRecver, error) { +func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) { return p.client.CreateChatCompletionStream(ctx, req) } diff --git a/pkg/bridge/ai/provider/cfazure/provider.go b/pkg/bridge/ai/provider/cfazure/provider.go index 3cff8c87c..9413bffe0 100644 --- a/pkg/bridge/ai/provider/cfazure/provider.go +++ b/pkg/bridge/ai/provider/cfazure/provider.go @@ -7,13 +7,11 @@ import ( "context" "fmt" - _ "github.com/joho/godotenv/autoload" openai "github.com/sashabaranov/go-openai" "github.com/yomorun/yomo/core/metadata" "github.com/yomorun/yomo/core/ylog" - "github.com/yomorun/yomo/pkg/bridge/ai" - bridgeai "github.com/yomorun/yomo/pkg/bridge/ai" + "github.com/yomorun/yomo/pkg/bridge/ai/provider" ) // Provider is the provider for Azure OpenAI @@ -27,7 +25,7 @@ type Provider struct { } // check if implements ai.Provider -var _ bridgeai.LLMProvider = &Provider{} +var _ provider.LLMProvider = &Provider{} // NewProvider creates a new AzureOpenAIProvider func NewProvider(cfEndpoint string, apiKey string, resource string, deploymentID string, apiVersion string) *Provider { @@ -62,14 +60,14 @@ func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatComple } // GetChatCompletionsStream implements ai.LLMProvider. 
-func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (ai.ResponseRecver, error) { +func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) { return p.client.CreateChatCompletionStream(ctx, req) } func newConfig(cfEndpoint string, apiKey string, resource string, deploymentID string, apiVersion string) openai.ClientConfig { - baseUrl := fmt.Sprintf("%s/azure-openai/%s/%s", cfEndpoint, resource, deploymentID) + baseURL := fmt.Sprintf("%s/azure-openai/%s/%s", cfEndpoint, resource, deploymentID) - config := openai.DefaultAzureConfig(apiKey, baseUrl) + config := openai.DefaultAzureConfig(apiKey, baseURL) config.APIType = openai.APITypeCloudflareAzure config.APIVersion = apiVersion diff --git a/pkg/bridge/ai/provider/cfopenai/provider.go b/pkg/bridge/ai/provider/cfopenai/provider.go index bd77879a3..7fa61f65c 100644 --- a/pkg/bridge/ai/provider/cfopenai/provider.go +++ b/pkg/bridge/ai/provider/cfopenai/provider.go @@ -12,8 +12,7 @@ import ( "github.com/yomorun/yomo/core/metadata" "github.com/yomorun/yomo/core/ylog" - "github.com/yomorun/yomo/pkg/bridge/ai" - bridgeai "github.com/yomorun/yomo/pkg/bridge/ai" + "github.com/yomorun/yomo/pkg/bridge/ai/provider" ) // Provider is the provider for Cloudflare OpenAI Gateway @@ -28,7 +27,7 @@ type Provider struct { } // check if implements ai.Provider -var _ bridgeai.LLMProvider = &Provider{} +var _ provider.LLMProvider = &Provider{} // NewProvider creates a new AzureOpenAIProvider func NewProvider(cfEndpoint, apiKey, model string) *Provider { @@ -63,7 +62,7 @@ func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatComple } // GetChatCompletionsStream implements ai.LLMProvider. -func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (ai.ResponseRecver, error) { +func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) { req.Model = p.Model return p.client.CreateChatCompletionStream(ctx, req) diff --git a/pkg/bridge/ai/provider/mock.go b/pkg/bridge/ai/provider/mock.go new file mode 100644 index 000000000..3d397dada --- /dev/null +++ b/pkg/bridge/ai/provider/mock.go @@ -0,0 +1,150 @@ +// Package provider defines the ai.Provider interface and provides a mock provider for unit tests. +package provider + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "strings" + + "github.com/sashabaranov/go-openai" + "github.com/yomorun/yomo/core/metadata" +) + +// Mock implements the ai.Provider interface. +// It can be used to record requests and mock responses. +type Mock struct { + name string + + reqs []openai.ChatCompletionRequest + + // each call to a completions method returns and removes one element from resp or streamResp. + resp []openai.ChatCompletionResponse + streamResp []*ChatCompletionStreamResponse +} + +// ChatCompletionStreamResponse has a Recv() function, so it implements the ResponseRecver interface. +type ChatCompletionStreamResponse struct { + items []openai.ChatCompletionStreamResponse +} + +// NewMock returns a mock provider. +func NewMock(name string, data ...MockData) (*Mock, error) { + p := &Mock{ + name: name, + } + if len(data) == 0 { + return p, nil + } + + for _, d := range data { + if err := d.apply(p); err != nil { + return nil, err + } + } + return p, nil +} + +// Recv implements the ResponseRecver interface. 
+func (m *ChatCompletionStreamResponse) Recv() (openai.ChatCompletionStreamResponse, error) { + if len(m.items) == 0 { + return openai.ChatCompletionStreamResponse{}, io.EOF + } + item := m.items[0] + m.items = m.items[1:] + return item, nil +} + +// MockData supplies mock response data to the mock provider. +type MockData interface { + apply(*Mock) error +} + +type applyFunc func(*Mock) error + +func (f applyFunc) apply(mp *Mock) error { return f(mp) } + +// MockChatCompletionResponse supplies mock response data to the mock provider. +func MockChatCompletionResponse(str ...string) MockData { + return applyFunc(func(m *Mock) error { + m.resp = make([]openai.ChatCompletionResponse, len(str)) + for i, s := range str { + if err := json.Unmarshal([]byte(s), &m.resp[i]); err != nil { + return err + } + } + return nil + }) +} + +// MockChatCompletionStreamResponse supplies mock response data in the form of a stream to the mock provider. +func MockChatCompletionStreamResponse(str ...string) MockData { + streamRespArr := make([]*ChatCompletionStreamResponse, len(str)) + for i, s := range str { + scanner := bufio.NewScanner(strings.NewReader(s)) + scanner.Split(bufio.ScanLines) + + var ( + err error + streamResp = new(ChatCompletionStreamResponse) + ) + for scanner.Scan() { + text := scanner.Text() + if strings.HasPrefix(text, "data: ") { + jsonStr := text[6:] + if jsonStr == "[DONE]" { + break + } + var item openai.ChatCompletionStreamResponse + if err = json.Unmarshal([]byte(jsonStr), &item); err != nil { + return applyFunc(func(m *Mock) error { + return err + }) + } + streamResp.items = append(streamResp.items, item) + } + } + streamRespArr[i] = streamResp + } + + return applyFunc(func(m *Mock) error { + m.streamResp = streamRespArr + return nil + }) +} + +// GetChatCompletions implements the ai.Provider interface. +func (m *Mock) GetChatCompletions(_ context.Context, req openai.ChatCompletionRequest, _ metadata.M) (openai.ChatCompletionResponse, error) { + data, _ := json.Marshal(&req) + fmt.Println("[mock provider] request:", string(data)) + + m.reqs = append(m.reqs, req) + + item := m.resp[0] + m.resp = m.resp[1:] + return item, nil +} + +// GetChatCompletionsStream implements the ai.Provider interface. +func (m *Mock) GetChatCompletionsStream(_ context.Context, req openai.ChatCompletionRequest, _ metadata.M) (ResponseRecver, error) { + data, _ := json.Marshal(&req) + fmt.Println("[mock provider] stream request:", string(data)) + + m.reqs = append(m.reqs, req) + + item := m.streamResp[0] + m.streamResp = m.streamResp[1:] + return item, nil +} + +// RequestRecords returns the request records. +func (m *Mock) RequestRecords() []openai.ChatCompletionRequest { + return m.reqs +} + +// Name returns the provider name. 
+func (m *Mock) Name() string { + return m.name +} diff --git a/pkg/bridge/ai/provider/mock_test.go b/pkg/bridge/ai/provider/mock_test.go new file mode 100644 index 000000000..289b3038b --- /dev/null +++ b/pkg/bridge/ai/provider/mock_test.go @@ -0,0 +1,95 @@ +package provider + +import ( + "context" + "io" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" +) + +func TestMockProviderRequest(t *testing.T) { + provider, err := NewMock("mock", + MockChatCompletionResponse(data), + MockChatCompletionStreamResponse(streamData)) + assert.NoError(t, err) + + reqs := []openai.ChatCompletionRequest{ + {Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "hi, llm bridge"}}}, + {Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "hi, yomo"}}}, + } + + provider.GetChatCompletions(context.TODO(), reqs[0], nil) + provider.GetChatCompletionsStream(context.TODO(), reqs[1], nil) + + assert.Equal(t, reqs, provider.RequestRecords()) +} + +func TestMockProvider(t *testing.T) { + provider, err := NewMock("mock", + MockChatCompletionResponse(data), + MockChatCompletionStreamResponse(streamData)) + assert.NoError(t, err) + + t.Run("Name()", func(t *testing.T) { + assert.Equal(t, "mock", provider.Name()) + }) + + t.Run("GetChatCompletions()", func(t *testing.T) { + resp, err := provider.GetChatCompletions(context.TODO(), openai.ChatCompletionRequest{Model: "gpt-4o-2024-05-13"}, nil) + assert.NoError(t, err) + assert.Equal(t, "chatcmpl-9b9wyqGfbZHi0pPDfYgYKXAV1obkm", resp.ID) + assert.Equal(t, "gpt-4o-2024-05-13", resp.Model) + assert.Equal(t, "One plus one equals two.", resp.Choices[0].Message.Content) + }) + + t.Run("GetChatCompletionsStream()", func(t *testing.T) { + recver, err := provider.GetChatCompletionsStream(context.TODO(), openai.ChatCompletionRequest{}, nil) + assert.NoError(t, err) + + for { + resp, err := recver.Recv() + if err == io.EOF { + break + } + assert.NoError(t, err) + assert.Equal(t, "chatcmpl-9b2Ah9pTmqiVMkwZEPBLqJpLzFzGG", resp.ID) + assert.Equal(t, "gpt-4o-2024-05-13", resp.Model) + } + }) +} + +var data = `{ + "id": "chatcmpl-9b9wyqGfbZHi0pPDfYgYKXAV1obkm", + "object": "chat.completion", + "created": 1718643412, + "model": "gpt-4o-2024-05-13", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "One plus one equals two." 
+ }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 13, + "completion_tokens": 6, + "total_tokens": 19 + }, + "system_fingerprint": "fp_319be4768e" + }` + +var streamData = `data: {"id":"chatcmpl-9b2Ah9pTmqiVMkwZEPBLqJpLzFzGG","object":"chat.completion.chunk","created":1718613511,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_aa87380ac5","choices":[{"index":0,"delta":{"content":" accurate"},"logprobs":null,"finish_reason":null}]} + + data: {"id":"chatcmpl-9b2Ah9pTmqiVMkwZEPBLqJpLzFzGG","object":"chat.completion.chunk","created":1718613511,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_aa87380ac5","choices":[{"index":0,"delta":{"content":" conversion"},"logprobs":null,"finish_reason":null}]} + + data: {"id":"chatcmpl-9b2Ah9pTmqiVMkwZEPBLqJpLzFzGG","object":"chat.completion.chunk","created":1718613511,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_aa87380ac5","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}]} + + data: {"id":"chatcmpl-9b2Ah9pTmqiVMkwZEPBLqJpLzFzGG","object":"chat.completion.chunk","created":1718613511,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_aa87380ac5","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + + data: [DONE]` diff --git a/pkg/bridge/ai/provider/ollama/provider.go b/pkg/bridge/ai/provider/ollama/provider.go index 9c5c95869..2ce783ddb 100644 --- a/pkg/bridge/ai/provider/ollama/provider.go +++ b/pkg/bridge/ai/provider/ollama/provider.go @@ -1,3 +1,4 @@ +// Package ollama is used to provide the Ollama service for YoMo Bridge. package ollama import ( @@ -18,7 +19,7 @@ import ( openai "github.com/sashabaranov/go-openai" "github.com/yomorun/yomo/core/metadata" "github.com/yomorun/yomo/core/ylog" - "github.com/yomorun/yomo/pkg/bridge/ai" + "github.com/yomorun/yomo/pkg/bridge/ai/provider" "github.com/yomorun/yomo/pkg/id" ) @@ -295,7 +296,7 @@ func (s *streamResponse) Recv() (openai.ChatCompletionStreamResponse, error) { } // GetChatCompletionsStream implements ai.LLMProvider. 
-func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (ai.ResponseRecver, error) {
+func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) {
 	urlPath, err := url.JoinPath(p.Endpoint, "api/generate")
 	if err != nil {
 		return nil, err
diff --git a/pkg/bridge/ai/provider/openai/provider.go b/pkg/bridge/ai/provider/openai/provider.go
index 76baef634..6f34b18ab 100644
--- a/pkg/bridge/ai/provider/openai/provider.go
+++ b/pkg/bridge/ai/provider/openai/provider.go
@@ -9,14 +9,10 @@ import (
 	_ "github.com/joho/godotenv/autoload"
 	"github.com/sashabaranov/go-openai"
 	"github.com/yomorun/yomo/core/metadata"
-	"github.com/yomorun/yomo/core/ylog"
-	bridgeai "github.com/yomorun/yomo/pkg/bridge/ai"
+	provider "github.com/yomorun/yomo/pkg/bridge/ai/provider"
 )
 
-// APIEndpoint is the endpoint for OpenAI
-const APIEndpoint = "https://api.openai.com/v1/chat/completions"
-
 // Provider is the provider for OpenAI
 type Provider struct {
 	// APIKey is the API key for OpenAI
@@ -28,7 +24,7 @@ type Provider struct {
 }
 
 // check if implements ai.Provider
-var _ bridgeai.LLMProvider = &Provider{}
+var _ provider.LLMProvider = &Provider{}
 
 // NewProvider creates a new OpenAIProvider
 func NewProvider(apiKey string, model string) *Provider {
@@ -38,12 +34,15 @@ func NewProvider(apiKey string, model string) *Provider {
 	if model == "" {
 		model = os.Getenv("OPENAI_MODEL")
 	}
+	c := openai.DefaultConfig(apiKey)
+	if v, ok := os.LookupEnv("OPENAI_BASE_URL"); ok {
+		c.BaseURL = v
+	}
 
-	ylog.Debug("new openai provider", "api_endpoint", APIEndpoint, "api_key", apiKey, "model", model)
 	return &Provider{
 		APIKey: apiKey,
 		Model:  model,
-		client: openai.NewClient(apiKey),
+		client: openai.NewClientWithConfig(c),
 	}
 }
 
@@ -60,7 +59,7 @@ func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatComple
 }
 
 // GetChatCompletionsStream implements ai.LLMProvider.
-func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (bridgeai.ResponseRecver, error) {
+func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) {
 	req.Model = p.Model
 
 	return p.client.CreateChatCompletionStream(ctx, req)
diff --git a/pkg/bridge/ai/provider.go b/pkg/bridge/ai/provider/provider.go
similarity index 58%
rename from pkg/bridge/ai/provider.go
rename to pkg/bridge/ai/provider/provider.go
index 236aa9568..b3befdb1a 100644
--- a/pkg/bridge/ai/provider.go
+++ b/pkg/bridge/ai/provider/provider.go
@@ -1,13 +1,17 @@
-package ai
+package provider
 
 import (
 	"context"
+	"errors"
 	"sync"
 
 	openai "github.com/sashabaranov/go-openai"
 	"github.com/yomorun/yomo/core/metadata"
 )
 
+// ErrNotExistsProvider is the error when the provider does not exist
+var ErrNotExistsProvider = errors.New("llm provider does not exist")
+
 // LLMProvider provides an interface to the llm providers
 type LLMProvider interface {
 	// Name returns the name of the llm provider
@@ -25,9 +29,7 @@ type ResponseRecver interface {
 }
 
 var (
-	providers       sync.Map
-	defaultProvider LLMProvider
-	mu              sync.Mutex
+	providers sync.Map
 )
 
 // RegisterProvider registers the llm provider
@@ -47,51 +49,18 @@ func ListProviders() []string {
 	return names
 }
 
-// SetDefaultProvider sets the default llm provider
-func SetDefaultProvider(name string) {
-	provider := GetProvider(name)
-	if provider != nil {
-		setDefaultProvider(provider)
-	}
-}
-
-func setDefaultProvider(provider LLMProvider) {
-	mu.Lock()
-	defer mu.Unlock()
-	defaultProvider = provider
-}
-
-// GetProvider returns the llm provider by name
-func GetProvider(name string) LLMProvider {
+func getProvider(name string) LLMProvider {
 	if provider, ok := providers.Load(name); ok {
 		return provider.(LLMProvider)
 	}
 	return nil
 }
 
-// GetProviderAndSetDefault returns the llm provider by name and set it as the default provider
-func GetProviderAndSetDefault(name string) (LLMProvider, error) {
-	provider := GetProvider(name)
+// GetProvider returns the llm provider by name
+func GetProvider(name string) (LLMProvider, error) {
+	provider := getProvider(name)
 	if provider != nil {
-		setDefaultProvider(provider)
 		return provider, nil
 	}
 	return nil, ErrNotExistsProvider
 }
-
-// GetDefaultProvider returns the default llm provider
-func GetDefaultProvider() (LLMProvider, error) {
-	mu.Lock()
-	defer mu.Unlock()
-	if defaultProvider != nil {
-		return defaultProvider, nil
-	}
-	names := ListProviders()
-	if len(names) > 0 {
-		p := GetProvider(names[0])
-		if p != nil {
-			return p, nil
-		}
-	}
-	return nil, ErrNotExistsProvider
-}
diff --git a/pkg/bridge/ai/provider/provider_test.go b/pkg/bridge/ai/provider/provider_test.go
new file mode 100644
index 000000000..bd55bdc18
--- /dev/null
+++ b/pkg/bridge/ai/provider/provider_test.go
@@ -0,0 +1,35 @@
+package provider
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestProviders(t *testing.T) {
+	p1, _ := NewMock("name-1")
+	p2, _ := NewMock("name-2")
+	p3, _ := NewMock("name-3")
+
+	RegisterProvider(p1)
+	RegisterProvider(p2)
+	RegisterProvider(p3)
+
+	t.Run("ListProviders", func(t *testing.T) {
+		val := ListProviders()
+		expected := []string{"name-1", "name-2", "name-3"}
+		assert.ElementsMatch(t, expected, val)
+	})
+
+	t.Run("GetProvider error", func(t *testing.T) {
+		_, err := GetProvider("name-not-exist")
+		assert.ErrorIs(t, err, ErrNotExistsProvider)
+	})
+
+	t.Run("GetProvider", func(t *testing.T) {
+		p, err := GetProvider("name-1")
+		assert.NoError(t, err)
+		assert.Equal(t, p1, p)
+	})
+
+}
diff --git a/pkg/bridge/ai/register/register.go b/pkg/bridge/ai/register/register.go
index 48247e3d7..19927d6d8 100644
--- a/pkg/bridge/ai/register/register.go
+++ b/pkg/bridge/ai/register/register.go
@@ -33,6 +33,11 @@ func GetRegister() Register {
 	return defaultRegister
 }
 
+// NewDefault creates a new default register.
+func NewDefault() Register {
+	return &register{}
+}
+
 // ListToolCalls returns the list of tool calls
 func ListToolCalls(md metadata.M) (map[uint32]openai.Tool, error) {
 	return defaultRegister.ListToolCalls(md)
@@ -48,11 +53,6 @@ func UnregisterFunction(connID uint64, md metadata.M) {
 	defaultRegister.UnregisterFunction(connID, md)
 }
 
-// SfnFactor returns the sfn factor
-func SfnFactor(tag uint32, md metadata.M) int {
-	return defaultRegister.SfnFactor(tag, md)
-}
-
 type connectedFn struct {
 	connID uint64
 	tag    uint32
@@ -67,8 +67,6 @@ type Register interface {
 	RegisterFunction(tag uint32, functionDefinition *openai.FunctionDefinition, connID uint64, md metadata.M) error
 	// UnregisterFunction unregisters a function calling function
 	UnregisterFunction(connID uint64, md metadata.M)
-	// SfnFactor returns the sfn factor
-	SfnFactor(tag uint32, md metadata.M) int
 }
 
 type register struct {
@@ -103,16 +101,3 @@ func (r *register) RegisterFunction(tag uint32, functionDefinition *ai.FunctionD
 func (r *register) UnregisterFunction(connID uint64, _ metadata.M) {
 	r.underlying.Delete(connID)
 }
-
-// SfnFactor returns the sfn factor
-func (r *register) SfnFactor(tag uint32, md metadata.M) int {
-	factor := 0
-	r.underlying.Range(func(key, value any) bool {
-		fn := value.(*connectedFn)
-		if fn.tag == tag {
-			factor++
-		}
-		return true
-	})
-	return factor
-}
diff --git a/pkg/bridge/ai/register/register_test.go b/pkg/bridge/ai/register/register_test.go
index 15b799a8b..53ffcb908 100644
--- a/pkg/bridge/ai/register/register_test.go
+++ b/pkg/bridge/ai/register/register_test.go
@@ -6,11 +6,10 @@ import (
 	"github.com/sashabaranov/go-openai"
 	"github.com/stretchr/testify/assert"
 	"github.com/yomorun/yomo/ai"
-	"github.com/yomorun/yomo/core/metadata"
 )
 
 func TestRegister(t *testing.T) {
-	r := &register{}
+	r := NewDefault()
 	SetRegister(r)
 
 	assert.Equal(t, r, GetRegister())
@@ -53,29 +52,3 @@ func assertToolCalls(t *testing.T, wantTag uint32, want *ai.FunctionDefinition,
 	assert.Equal(t, wantTag, tag)
 	assert.Equal(t, want, got.Function)
 }
-
-func TestSfnFactor(t *testing.T) {
-	functionDefinition := &ai.FunctionDefinition{
-		Name:        "function1",
-		Description: "desc1",
-		Parameters: &ai.FunctionParameters{
-			Type: "type1",
-			Properties: map[string]*ai.ParameterProperty{
-				"prop1": {Type: "type1", Description: "desc1"},
-				"prop2": {Type: "type2", Description: "desc2"},
-			},
-			Required: []string{"prop1"},
-		},
-	}
-	RegisterFunction(1, functionDefinition, 1, nil)
-	assert.Equal(t, 1, SfnFactor(1, nil))
-
-	RegisterFunction(1, functionDefinition, 2, metadata.M{})
-	assert.Equal(t, 2, SfnFactor(1, metadata.M{}))
-
-	UnregisterFunction(1, nil)
-	assert.Equal(t, 1, SfnFactor(1, nil))
-
-	UnregisterFunction(2, metadata.M{})
-	assert.Equal(t, 0, SfnFactor(1, metadata.M{}))
-}
diff --git a/pkg/listener/quic/quic.go b/pkg/listener/quic/quic.go
index 0963edb48..07eaea6b0 100644
--- a/pkg/listener/quic/quic.go
+++ b/pkg/listener/quic/quic.go
@@ -1,3 +1,4 @@
+// Package yquic provides a QUIC implementation of yomo.FrameConn.
 package yquic
 
 import (
diff --git a/pkg/trace/trace.go b/pkg/trace/trace.go
index cb8d94187..8cbbb2b92 100644
--- a/pkg/trace/trace.go
+++ b/pkg/trace/trace.go
@@ -1,3 +1,4 @@
+// Package trace provides an otel span tracer for YoMo.
 package trace
 
 import (
diff --git a/serverless/guest/context.go b/serverless/guest/context.go
index 49485b4c6..62b1d11c4 100644
--- a/serverless/guest/context.go
+++ b/serverless/guest/context.go
@@ -1,4 +1,4 @@
-// guest wasm application programming interface for guest module
+// Package guest provides the wasm application programming interface for guest modules.
 package guest
 
 import (
diff --git a/serverless/mock/mock_context.go b/serverless/mock/mock_context.go
index 9d079c200..d3d4d8d34 100644
--- a/serverless/mock/mock_context.go
+++ b/serverless/mock/mock_context.go
@@ -1,3 +1,4 @@
+// Package mock provides a mock context for stream functions.
 package mock
 
 import (
@@ -53,7 +54,7 @@ func (c *MockContext) Metadata(_ string) (string, bool) {
 }
 
 // HTTP returns the HTTP interface.
-func (m *MockContext) HTTP() serverless.HTTP {
+func (c *MockContext) HTTP() serverless.HTTP {
 	panic("not implemented, to use `net/http` package")
 }
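The OpenAI provider now builds its client from openai.DefaultConfig instead of a hard-coded APIEndpoint constant, so an OpenAI-compatible gateway can be targeted by exporting OPENAI_BASE_URL. A minimal sketch of the resulting usage; the URL and model values are illustrative, not part of this change:

    package main

    import (
        "fmt"
        "os"

        "github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
    )

    func main() {
        // The base URL is read once, inside NewProvider, via os.LookupEnv,
        // so it must be set before the provider is constructed.
        os.Setenv("OPENAI_BASE_URL", "https://my-gateway.example.com/v1")

        // An empty key falls back to the OPENAI_API_KEY environment variable.
        p := openai.NewProvider("", "gpt-4o")
        fmt.Println("model:", p.Model)
    }

A side benefit: the removed ylog.Debug call used to log the API key, which NewProvider no longer does.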
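With SetDefaultProvider, GetDefaultProvider, and GetProviderAndSetDefault gone, callers resolve a provider explicitly and handle the lookup failure; GetProvider now returns ErrNotExistsProvider instead of a bare nil. A sketch of the new call pattern, assuming the OpenAI provider registers itself under the name "openai" and was registered at startup (the CLI does this from the bridge config); the key and model literals are placeholders:

    package main

    import (
        "fmt"

        "github.com/yomorun/yomo/pkg/bridge/ai/provider"
        "github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
    )

    func main() {
        // Register once at startup; the registry stores providers by name.
        provider.RegisterProvider(openai.NewProvider("sk-placeholder", "gpt-4o"))

        // Look up explicitly; there is no implicit default provider any more.
        p, err := provider.GetProvider("openai")
        if err != nil {
            // provider.ErrNotExistsProvider: nothing registered under that name.
            fmt.Println(err)
            return
        }
        fmt.Println("using provider:", p.Name())
    }

Dropping the default-provider state also removes the mutex that guarded it, leaving a single sync.Map as the registry.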
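register.NewDefault lets code outside the package construct the default register without reaching for the unexported register type, which register_test.go now uses in place of &register{}. A hypothetical external test using it to isolate registry state; the tag, connID, and function name are arbitrary values chosen for illustration:

    package register_test

    import (
        "testing"

        "github.com/stretchr/testify/assert"
        "github.com/yomorun/yomo/ai"
        "github.com/yomorun/yomo/pkg/bridge/ai/register"
    )

    func TestIsolatedRegister(t *testing.T) {
        // Swap in a fresh register so state from other tests cannot leak in.
        register.SetRegister(register.NewDefault())

        fd := &ai.FunctionDefinition{Name: "get-weather", Description: "illustrative"}
        register.RegisterFunction(1, fd, 100, nil)

        // One connected function should yield one tool call.
        tcs, err := register.ListToolCalls(nil)
        assert.NoError(t, err)
        assert.Len(t, tcs, 1)
    }

With SfnFactor removed from both the Register interface and the package-level helpers, the fan-out factor is no longer part of the register's public surface.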