
refactor: use goroutine instead of mutex to sync messages #842

Merged · 25 commits · Jun 28, 2024 · Changes from all commits
14 changes: 7 additions & 7 deletions ai/model.go
@@ -22,17 +22,17 @@ type InvokeRequest struct {
// InvokeResponse is the response for chat completions
type InvokeResponse struct {
	// Content is the content from llm api response
-	Content string
+	Content string `json:"content,omitempty"`
	// ToolCalls is the toolCalls from llm api response
-	ToolCalls map[uint32][]*openai.ToolCall
+	ToolCalls map[uint32][]*openai.ToolCall `json:"tool_calls,omitempty"`
	// ToolMessages is the tool messages from llm api response
-	ToolMessages []ToolMessage
+	ToolMessages []ToolMessage `json:"tool_messages,omitempty"`
	// FinishReason is the finish reason from llm api response
-	FinishReason string
+	FinishReason string `json:"finish_reason,omitempty"`
	// TokenUsage is the token usage from llm api response
-	TokenUsage TokenUsage
+	TokenUsage TokenUsage `json:"token_usage,omitempty"`
	// AssistantMessage is the assistant message from llm api response, only present when finish reason is "tool_calls"
-	AssistantMessage interface{}
+	AssistantMessage interface{} `json:"assistant_message,omitempty"`
}

// TokenUsage is the token usage in Response
@@ -63,7 +63,7 @@ type ParameterProperty struct {
type ToolMessage struct {
	Role string `json:"role"`
	Content string `json:"content"`
-	ToolCallId string `json:"tool_call_id"`
+	ToolCallID string `json:"tool_call_id"`
}

// ChainMessage is the message for chaining llm request with preceeding `tool_calls` response
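Editor's note: the json tags added above change how an InvokeResponse is serialized: with `omitempty`, zero-valued fields are dropped from the payload instead of being emitted as empty strings, nulls, and zeros. A minimal sketch of the effect, using a trimmed-down stand-in struct rather than the actual yomo type:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// response mirrors two fields of InvokeResponse: one with the new
// `omitempty` tag and one without, to contrast the encodings.
type response struct {
	Content      string `json:"content,omitempty"`
	FinishReason string `json:"finish_reason"` // no omitempty, for comparison
}

func main() {
	b, _ := json.Marshal(response{FinishReason: "stop"})
	// Content is zero-valued, so it is omitted entirely:
	// {"finish_reason":"stop"}
	fmt.Println(string(b))
}
```

The same applies to the map, slice, and interface fields above, which are now omitted when nil.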
1 change: 1 addition & 0 deletions ai/openai.go
@@ -8,6 +8,7 @@ import (
"github.com/yomorun/yomo/core/ylog"
)

+// ConvertToInvokeResponse converts openai.ChatCompletionResponse struct to InvokeResponse struct.
func ConvertToInvokeResponse(res *openai.ChatCompletionResponse, tcs map[uint32]openai.Tool) (*InvokeResponse, error) {
choice := res.Choices[0]
ylog.Debug(">>finish_reason", "reason", choice.FinishReason)
18 changes: 9 additions & 9 deletions cli/serve.go
@@ -27,6 +27,7 @@ import (
	"github.com/yomorun/yomo/pkg/log"

	"github.com/yomorun/yomo/pkg/bridge/ai"
+	providerpkg "github.com/yomorun/yomo/pkg/bridge/ai/provider"
	"github.com/yomorun/yomo/pkg/bridge/ai/provider/azopenai"
	"github.com/yomorun/yomo/pkg/bridge/ai/provider/cfazure"
	"github.com/yomorun/yomo/pkg/bridge/ai/provider/cfopenai"
@@ -78,7 +79,7 @@ var serveCmd = &cobra.Command{
		}
		if aiConfig != nil {
			// add AI connection middleware
-			options = append(options, yomo.WithZipperConnMiddleware(ai.ConnMiddleware))
+			options = append(options, yomo.WithZipperConnMiddleware(ai.RegisterFunctionMW()))
		}
		// new zipper
		zipper, err := yomo.NewZipper(
@@ -97,7 +98,7 @@ var serveCmd = &cobra.Command{
		registerAIProvider(aiConfig)
		// start the llm api server
		go func() {
-			err := ai.Serve(aiConfig, listenAddr, fmt.Sprintf("token:%s", tokenString))
+			err := ai.Serve(aiConfig, listenAddr, fmt.Sprintf("token:%s", tokenString), ylog.Default())
			if err != nil {
				log.FailureStatusEvent(os.Stdout, err.Error())
				return
@@ -118,37 +119,36 @@ func registerAIProvider(aiConfig *ai.Config) error {
	for name, provider := range aiConfig.Providers {
		switch name {
		case "azopenai":
-			ai.RegisterProvider(azopenai.NewProvider(
+			providerpkg.RegisterProvider(azopenai.NewProvider(
				provider["api_key"],
				provider["api_endpoint"],
				provider["deployment_id"],
				provider["api_version"],
			))
		case "openai":
-			ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
+			providerpkg.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
		case "cloudflare_azure":
-			ai.RegisterProvider(cfazure.NewProvider(
+			providerpkg.RegisterProvider(cfazure.NewProvider(
				provider["endpoint"],
				provider["api_key"],
				provider["resource"],
				provider["deployment_id"],
				provider["api_version"],
			))
		case "cloudflare_openai":
-			ai.RegisterProvider(cfopenai.NewProvider(
+			providerpkg.RegisterProvider(cfopenai.NewProvider(
				provider["endpoint"],
				provider["api_key"],
				provider["model"],
			))
		case "ollama":
-			ai.RegisterProvider(ollama.NewProvider(provider["api_endpoint"]))
+			providerpkg.RegisterProvider(ollama.NewProvider(provider["api_endpoint"]))
		default:
			log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name)
		}
	}

-	// log.InfoStatusEvent(os.Stdout, "registered [%d] AI provider", len(ai.ListProviders()))
-	ylog.Info("registered AI providers", "len", len(ai.ListProviders()))
+	ylog.Info("registered AI providers", "len", len(providerpkg.ListProviders()))
	return nil
}
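Editor's note: provider registration moves out of the `ai` bridge package into the dedicated `provider` package (aliased as `providerpkg`). A minimal sketch of registering a provider against the new package directly, using only the calls visible in this diff; the import path for the openai provider is assumed to follow the other provider packages in the import block above, and the key/model values are placeholders:

```go
package main

import (
	"fmt"

	providerpkg "github.com/yomorun/yomo/pkg/bridge/ai/provider"
	// assumed path, following the azopenai/cfazure/cfopenai imports above
	"github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
)

func main() {
	// Mirror the "openai" case of registerAIProvider above; the api key
	// and model name are placeholders.
	providerpkg.RegisterProvider(openai.NewProvider("sk-placeholder", "gpt-4o"))

	// ListProviders reports what has been registered, as logged above.
	fmt.Println("registered AI providers:", len(providerpkg.ListProviders()))
}
```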
2 changes: 1 addition & 1 deletion cli/test.go
@@ -214,7 +214,7 @@ var testPromptCmd = &cobra.Command{
func getToolCallResult(tc *openai.ToolCall, tms []ai.ToolMessage) string {
result := ""
for _, tm := range tms {
-if tm.ToolCallId == tc.ID {
+if tm.ToolCallID == tc.ID {
result = tm.Content
}
}
3 changes: 2 additions & 1 deletion core/connection.go
@@ -44,7 +44,8 @@ type Connection struct {
Logger *slog.Logger
}

-func newConnection(
+// NewConnection creates a new connection according to the parameters.
+func NewConnection(
id uint64,
name string, clientID string, clientType ClientType, md metadata.M, tags []uint32,
fconn frame.Conn, logger *slog.Logger,
2 changes: 1 addition & 1 deletion core/connection_test.go
@@ -18,7 +18,7 @@ func TestConnection(t *testing.T) {
md metadata.M
)

-connection := newConnection(id, name, clientID, styp, md, observed, nil, ylog.Default())
+connection := NewConnection(id, name, clientID, styp, md, observed, nil, ylog.Default())

t.Run("ConnectionInfo", func(t *testing.T) {
assert.Equal(t, id, connection.ID())
2 changes: 1 addition & 1 deletion core/connector_test.go
@@ -117,5 +117,5 @@ func TestConnector(t *testing.T) {
// mockConn returns a connection that only includes an ID and a name.
// This function is used for unit testing purposes.
func mockConn(id uint64, name string) *Connection {
-return newConnection(id, name, "mock-id", ClientType(0), nil, []frame.Tag{0}, nil, ylog.Default())
+return NewConnection(id, name, "mock-id", ClientType(0), nil, []frame.Tag{0}, nil, ylog.Default())
}
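Editor's note: exporting `newConnection` as `NewConnection` lets code outside the `core` package construct a `Connection` directly, which is what mock helpers like the one above rely on. A hedged sketch of an external test doing the same; the parameter order follows the call sites in this diff, and the `core/frame` import path is assumed:

```go
package core_test

import (
	"testing"

	"github.com/yomorun/yomo/core"
	"github.com/yomorun/yomo/core/frame"
	"github.com/yomorun/yomo/core/ylog"
)

// TestMockConnection builds a Connection from outside the core package,
// mirroring the mockConn helper in connector_test.go.
func TestMockConnection(t *testing.T) {
	conn := core.NewConnection(
		1, "my-sfn", "mock-client-id", core.ClientType(0),
		nil, []frame.Tag{0}, nil, ylog.Default(),
	)
	if conn.Name() != "my-sfn" {
		t.Fatalf("unexpected connection name: %q", conn.Name())
	}
}
```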
4 changes: 2 additions & 2 deletions core/router/router.go
@@ -33,9 +33,9 @@ type defaultRouter struct {
data map[frame.Tag]map[uint64]struct{}
}

-// DefaultRouter provides a default implementation of `router`,
+// Default provides a default implementation of `router`,
// It routes data according to observed tag and metadata.
-func Default() *defaultRouter {
+func Default() Router {
return &defaultRouter{
targets: make(map[uint64]string),
data: make(map[frame.Tag]map[uint64]struct{}),
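Editor's note: `Default()` now returns the exported `Router` interface instead of the unexported `*defaultRouter`, so callers are no longer handed a type they cannot name, and alternative router implementations can sit behind the same constructor. A generic sketch of that constructor-returns-interface pattern with hypothetical types (not yomo's actual `Router` interface, whose methods are not shown in this diff):

```go
package main

import "fmt"

// Store is a small, hypothetical interface standing in for Router.
type Store interface {
	Put(key, value string)
	Get(key string) (string, bool)
}

// defaultStore is unexported, like defaultRouter.
type defaultStore struct{ m map[string]string }

func (s *defaultStore) Put(key, value string) { s.m[key] = value }
func (s *defaultStore) Get(key string) (string, bool) {
	v, ok := s.m[key]
	return v, ok
}

// Default returns the interface, so callers never see the unexported type.
func Default() Store {
	return &defaultStore{m: make(map[string]string)}
}

func main() {
	s := Default()
	s.Put("tag", "target")
	fmt.Println(s.Get("tag"))
}
```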
2 changes: 1 addition & 1 deletion core/server.go
@@ -293,7 +293,7 @@ func (s *Server) createConnection(hf *frame.HandshakeFrame, md metadata.M, fconn
if hf.WantedTarget != "" {
md.Set(metadata.WantedTargetKey, hf.WantedTarget)
}
-conn := newConnection(
+conn := NewConnection(
incrID(),
hf.Name,
hf.ID,
1 change: 1 addition & 0 deletions options.go
@@ -112,6 +112,7 @@ var (
}
}

+// WithVersionNegotiateFunc sets the version negotiate function for the zipper
WithVersionNegotiateFunc = func(f core.VersionNegotiateFunc) ZipperOption {
return func(zo *zipperOptions) {
zo.serverOption = append(zo.serverOption, core.WithVersionNegotiateFunc(f))
65 changes: 35 additions & 30 deletions pkg/bridge/ai/ai.go
@@ -22,43 +22,48 @@
	ErrConfigFormatError = errors.New("ai config format is incorrect")
)

-// ConnMiddleware returns a ConnMiddleware that can be used to intercept the connection.
-func ConnMiddleware(next core.ConnHandler) core.ConnHandler {
-	return func(conn *core.Connection) {
-		connMd := conn.Metadata().Clone()
-		definition, ok := connMd.Get(ai.FunctionDefinitionKey)
-
-		defer func() {
-			// definition does not be transmitted in mesh network, It only works for handshake.
-			conn.Metadata().Set(ai.FunctionDefinitionKey, "")
-			next(conn)
-			if ok {
-				register.UnregisterFunction(conn.ID(), connMd)
-				conn.Logger.Info("unregister ai function", "name", conn.Name(), "connID", conn.ID())
-			}
-		}()
-
-		if conn.ClientType() != core.ClientTypeStreamFunction || !ok {
-			return
-		}
-
-		for _, tag := range conn.ObserveDataTags() {
-			// register ai function
-			fd := ai.FunctionDefinition{}
-			err := json.Unmarshal([]byte(definition), &fd)
-			if err != nil {
-				conn.Logger.Error("unmarshal function definition", "error", err)
-				return
-			}
-			err = register.RegisterFunction(tag, &fd, conn.ID(), connMd)
-			if err != nil {
-				conn.Logger.Error("failed to register ai function", "name", conn.Name(), "tag", tag, "err", err)
-				return
-			}
-			conn.Logger.Info("register ai function success", "name", conn.Name(), "tag", tag, "definition", string(definition))
-		}
-	}
-}
+// RegisterFunctionMW returns a ConnMiddleware that can be used to register an ai function.
+func RegisterFunctionMW() core.ConnMiddleware { return registerFunction(register.GetRegister()) }
+
+func registerFunction(r register.Register) core.ConnMiddleware {
+	return core.ConnMiddleware(func(next core.ConnHandler) core.ConnHandler {
+		return func(conn *core.Connection) {
+			connMd := conn.Metadata().Clone()
+			definition, ok := connMd.Get(ai.FunctionDefinitionKey)
+
+			defer func() {
+				if ok {
+					conn.Metadata().Set(ai.FunctionDefinitionKey, "")
+				}
+				// definition does not be transmitted in mesh network, It only works for handshake.
+				next(conn)
+				if ok {
+					register.UnregisterFunction(conn.ID(), connMd)
+					conn.Logger.Info("unregister ai function", "name", conn.Name(), "connID", conn.ID())
+				}
+			}()
+
+			if conn.ClientType() != core.ClientTypeStreamFunction || !ok {
+				return
+			}
+
+			for _, tag := range conn.ObserveDataTags() {
+				// register ai function
+				fd := ai.FunctionDefinition{}
+				err := json.Unmarshal([]byte(definition), &fd)
+				if err != nil {
+					conn.Logger.Error("unmarshal function definition", "error", err)
+					return
+				}
+				err = r.RegisterFunction(tag, &fd, conn.ID(), connMd)
+				if err != nil {
+					conn.Logger.Error("failed to register ai function", "name", conn.Name(), "tag", tag, "err", err)
+					return
+				}
+				conn.Logger.Info("register ai function success", "name", conn.Name(), "tag", tag, "definition", string(definition))
+			}
+		}
+	})
+}

// Config is the configuration of AI bridge.
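Editor's note: the middleware is now built in two layers: `RegisterFunctionMW` wires in the process-wide register via `register.GetRegister()`, while the unexported `registerFunction` takes the register as a parameter, so tests can inject a fake. The diff also shows the shape of a `core.ConnMiddleware`: a function that wraps a `core.ConnHandler` and returns another one. A hedged sketch of a custom middleware with the same shape, using only the types and methods visible in this diff (the package name is hypothetical):

```go
// Package aimw sketches a custom connection middleware (hypothetical package).
package aimw

import "github.com/yomorun/yomo/core"

// LoggingMW wraps the next handler and logs around it for every
// accepted connection, mirroring the structure of registerFunction above.
func LoggingMW() core.ConnMiddleware {
	return core.ConnMiddleware(func(next core.ConnHandler) core.ConnHandler {
		return func(conn *core.Connection) {
			conn.Logger.Info("connection accepted", "name", conn.Name(), "connID", conn.ID())
			next(conn)
			conn.Logger.Info("connection handler done", "name", conn.Name(), "connID", conn.ID())
		}
	})
}
```

It would be wired the same way the AI bridge middleware is in cli/serve.go, e.g. `yomo.WithZipperConnMiddleware(LoggingMW())`.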