Skip to content

Commit

Permalink
refactor: use goroutinue instead of mutex to sync messages (#842)
Browse files Browse the repository at this point in the history
# Description

The HTTP server startup creates a `CallerProvider` struct, which
provides a `Provide(credential string)` method that returns a `Caller`.

The `Caller` and `provider` work together to provide invoke and chat
completion functionality.

The `Caller` has a `CallerSyncer` attribute, which is responsible for
synchronizing pub/sub information to the req/res.

## Impact

The register no longer requires the Factor method, which ensures the
correct execution of yomo test-prompt.

On the server-side, it is guaranteed that `llm-sfn` can only have one
return, and if there are multiple returns, the first one will be used.
If there is no return, the client will be notified of a timeout.

# Test

Provide a mock `provider` which can record the requests to OpenAI and
mock the responses from OpenAI.

Provide a `mockDataFlow` which is used to mock the data flow in Yomo,
from the source -> handler -> reducer.
  • Loading branch information
woorui committed Jun 28, 2024
1 parent 016bdb7 commit 273b19a
Show file tree
Hide file tree
Showing 33 changed files with 1,748 additions and 831 deletions.
14 changes: 7 additions & 7 deletions ai/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ type InvokeRequest struct {
// InvokeResponse is the response for chat completions
type InvokeResponse struct {
// Content is the content from llm api response
Content string
Content string `json:"content,omitempty"`
// ToolCalls is the toolCalls from llm api response
ToolCalls map[uint32][]*openai.ToolCall
ToolCalls map[uint32][]*openai.ToolCall `json:"tool_calls,omitempty"`
// ToolMessages is the tool messages from llm api response
ToolMessages []ToolMessage
ToolMessages []ToolMessage `json:"tool_messages,omitempty"`
// FinishReason is the finish reason from llm api response
FinishReason string
FinishReason string `json:"finish_reason,omitempty"`
// TokenUsage is the token usage from llm api response
TokenUsage TokenUsage
TokenUsage TokenUsage `json:"token_usage,omitempty"`
// AssistantMessage is the assistant message from llm api response, only present when finish reason is "tool_calls"
AssistantMessage interface{}
AssistantMessage interface{} `json:"assistant_message,omitempty"`
}

// TokenUsage is the token usage in Response
Expand Down Expand Up @@ -63,7 +63,7 @@ type ParameterProperty struct {
type ToolMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCallId string `json:"tool_call_id"`
ToolCallID string `json:"tool_call_id"`
}

// ChainMessage is the message for chaining llm request with preceeding `tool_calls` response
Expand Down
1 change: 1 addition & 0 deletions ai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/yomorun/yomo/core/ylog"
)

// ConvertToInvokeResponse converts openai.ChatCompletionResponse struct to InvokeResponse struct.
func ConvertToInvokeResponse(res *openai.ChatCompletionResponse, tcs map[uint32]openai.Tool) (*InvokeResponse, error) {
choice := res.Choices[0]
ylog.Debug(">>finish_reason", "reason", choice.FinishReason)
Expand Down
18 changes: 9 additions & 9 deletions cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/yomorun/yomo/pkg/log"

"github.com/yomorun/yomo/pkg/bridge/ai"
providerpkg "github.com/yomorun/yomo/pkg/bridge/ai/provider"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/azopenai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/cfazure"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/cfopenai"
Expand Down Expand Up @@ -78,7 +79,7 @@ var serveCmd = &cobra.Command{
}
if aiConfig != nil {
// add AI connection middleware
options = append(options, yomo.WithZipperConnMiddleware(ai.ConnMiddleware))
options = append(options, yomo.WithZipperConnMiddleware(ai.RegisterFunctionMW()))
}
// new zipper
zipper, err := yomo.NewZipper(
Expand All @@ -97,7 +98,7 @@ var serveCmd = &cobra.Command{
registerAIProvider(aiConfig)
// start the llm api server
go func() {
err := ai.Serve(aiConfig, listenAddr, fmt.Sprintf("token:%s", tokenString))
err := ai.Serve(aiConfig, listenAddr, fmt.Sprintf("token:%s", tokenString), ylog.Default())
if err != nil {
log.FailureStatusEvent(os.Stdout, err.Error())
return
Expand All @@ -118,37 +119,36 @@ func registerAIProvider(aiConfig *ai.Config) error {
for name, provider := range aiConfig.Providers {
switch name {
case "azopenai":
ai.RegisterProvider(azopenai.NewProvider(
providerpkg.RegisterProvider(azopenai.NewProvider(
provider["api_key"],
provider["api_endpoint"],
provider["deployment_id"],
provider["api_version"],
))
case "openai":
ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
providerpkg.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
case "cloudflare_azure":
ai.RegisterProvider(cfazure.NewProvider(
providerpkg.RegisterProvider(cfazure.NewProvider(
provider["endpoint"],
provider["api_key"],
provider["resource"],
provider["deployment_id"],
provider["api_version"],
))
case "cloudflare_openai":
ai.RegisterProvider(cfopenai.NewProvider(
providerpkg.RegisterProvider(cfopenai.NewProvider(
provider["endpoint"],
provider["api_key"],
provider["model"],
))
case "ollama":
ai.RegisterProvider(ollama.NewProvider(provider["api_endpoint"]))
providerpkg.RegisterProvider(ollama.NewProvider(provider["api_endpoint"]))
default:
log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name)
}
}

// log.InfoStatusEvent(os.Stdout, "registered [%d] AI provider", len(ai.ListProviders()))
ylog.Info("registered AI providers", "len", len(ai.ListProviders()))
ylog.Info("registered AI providers", "len", len(providerpkg.ListProviders()))
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion cli/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ var testPromptCmd = &cobra.Command{
func getToolCallResult(tc *openai.ToolCall, tms []ai.ToolMessage) string {
result := ""
for _, tm := range tms {
if tm.ToolCallId == tc.ID {
if tm.ToolCallID == tc.ID {
result = tm.Content
}
}
Expand Down
3 changes: 2 additions & 1 deletion core/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ type Connection struct {
Logger *slog.Logger
}

func newConnection(
// NewConnection creates a new connection according to the parameters.
func NewConnection(
id uint64,
name string, clientID string, clientType ClientType, md metadata.M, tags []uint32,
fconn frame.Conn, logger *slog.Logger,
Expand Down
2 changes: 1 addition & 1 deletion core/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestConnection(t *testing.T) {
md metadata.M
)

connection := newConnection(id, name, clientID, styp, md, observed, nil, ylog.Default())
connection := NewConnection(id, name, clientID, styp, md, observed, nil, ylog.Default())

t.Run("ConnectionInfo", func(t *testing.T) {
assert.Equal(t, id, connection.ID())
Expand Down
2 changes: 1 addition & 1 deletion core/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,5 @@ func TestConnector(t *testing.T) {
// mockConn returns a connection that only includes an ID and a name.
// This function is used for unit testing purposes.
func mockConn(id uint64, name string) *Connection {
return newConnection(id, name, "mock-id", ClientType(0), nil, []frame.Tag{0}, nil, ylog.Default())
return NewConnection(id, name, "mock-id", ClientType(0), nil, []frame.Tag{0}, nil, ylog.Default())
}
4 changes: 2 additions & 2 deletions core/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ type defaultRouter struct {
data map[frame.Tag]map[uint64]struct{}
}

// DefaultRouter provides a default implementation of `router`,
// Default provides a default implementation of `router`,
// It routes data according to observed tag and metadata.
func Default() *defaultRouter {
func Default() Router {
return &defaultRouter{
targets: make(map[uint64]string),
data: make(map[frame.Tag]map[uint64]struct{}),
Expand Down
2 changes: 1 addition & 1 deletion core/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func (s *Server) createConnection(hf *frame.HandshakeFrame, md metadata.M, fconn
if hf.WantedTarget != "" {
md.Set(metadata.WantedTargetKey, hf.WantedTarget)
}
conn := newConnection(
conn := NewConnection(
incrID(),
hf.Name,
hf.ID,
Expand Down
1 change: 1 addition & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ var (
}
}

// WithVersionNegotiateFunc sets the version negotiate function for the zipper
WithVersionNegotiateFunc = func(f core.VersionNegotiateFunc) ZipperOption {
return func(zo *zipperOptions) {
zo.serverOption = append(zo.serverOption, core.WithVersionNegotiateFunc(f))
Expand Down
65 changes: 35 additions & 30 deletions pkg/bridge/ai/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,43 +22,48 @@ var (
ErrConfigFormatError = errors.New("ai config format is incorrect")
)

// ConnMiddleware returns a ConnMiddleware that can be used to intercept the connection.
func ConnMiddleware(next core.ConnHandler) core.ConnHandler {
return func(conn *core.Connection) {
connMd := conn.Metadata().Clone()
definition, ok := connMd.Get(ai.FunctionDefinitionKey)
// RegisterFunctionMW returns a ConnMiddleware that can be used to register an ai function.
func RegisterFunctionMW() core.ConnMiddleware { return registerFunction(register.GetRegister()) }

defer func() {
// definition does not be transmitted in mesh network, It only works for handshake.
conn.Metadata().Set(ai.FunctionDefinitionKey, "")
next(conn)
if ok {
register.UnregisterFunction(conn.ID(), connMd)
conn.Logger.Info("unregister ai function", "name", conn.Name(), "connID", conn.ID())
}
}()
func registerFunction(r register.Register) core.ConnMiddleware {
return core.ConnMiddleware(func(next core.ConnHandler) core.ConnHandler {
return func(conn *core.Connection) {
connMd := conn.Metadata().Clone()
definition, ok := connMd.Get(ai.FunctionDefinitionKey)

if conn.ClientType() != core.ClientTypeStreamFunction || !ok {
return
}
defer func() {
if ok {
conn.Metadata().Set(ai.FunctionDefinitionKey, "")
}
// definition does not be transmitted in mesh network, It only works for handshake.
next(conn)
if ok {
register.UnregisterFunction(conn.ID(), connMd)
conn.Logger.Info("unregister ai function", "name", conn.Name(), "connID", conn.ID())
}
}()

for _, tag := range conn.ObserveDataTags() {
// register ai function
fd := ai.FunctionDefinition{}
err := json.Unmarshal([]byte(definition), &fd)
if err != nil {
conn.Logger.Error("unmarshal function definition", "error", err)
if conn.ClientType() != core.ClientTypeStreamFunction || !ok {
return
}
err = register.RegisterFunction(tag, &fd, conn.ID(), connMd)
if err != nil {
conn.Logger.Error("failed to register ai function", "name", conn.Name(), "tag", tag, "err", err)
return

for _, tag := range conn.ObserveDataTags() {
// register ai function
fd := ai.FunctionDefinition{}
err := json.Unmarshal([]byte(definition), &fd)
if err != nil {
conn.Logger.Error("unmarshal function definition", "error", err)
return
}
err = r.RegisterFunction(tag, &fd, conn.ID(), connMd)
if err != nil {
conn.Logger.Error("failed to register ai function", "name", conn.Name(), "tag", tag, "err", err)
return
}
conn.Logger.Info("register ai function success", "name", conn.Name(), "tag", tag, "definition", string(definition))
}
conn.Logger.Info("register ai function success", "name", conn.Name(), "tag", tag, "definition", string(definition))
}

}
})
}

// Config is the configuration of AI bridge.
Expand Down
Loading

0 comments on commit 273b19a

Please sign in to comment.