diff --git a/cli/serve.go b/cli/serve.go index 28d881f84..b582f7371 100644 --- a/cli/serve.go +++ b/cli/serve.go @@ -28,6 +28,7 @@ import ( "github.com/yomorun/yomo/pkg/bridge/ai" "github.com/yomorun/yomo/pkg/bridge/ai/provider/azopenai" + "github.com/yomorun/yomo/pkg/bridge/ai/provider/gemini" "github.com/yomorun/yomo/pkg/bridge/ai/provider/openai" ) @@ -112,26 +113,26 @@ var serveCmd = &cobra.Command{ }, } -func registerAIProvider(aiConfig *ai.Config) { - // register the AI provider - for name, provider := range aiConfig.Providers { - // register LLM provider - switch name { - case "azopenai": - ai.RegisterProvider(azopenai.NewProvider( - provider["api_key"], - provider["api_endpoint"], - provider["deployment_id"], - provider["api_version"], - )) - log.InfoStatusEvent(os.Stdout, "register [%s] AI provider", name) - // TODO: register other providers - } - // register the OpenAI provider - if name == "openai" { - ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"])) - } - } +func registerAIProvider(aiConfig *ai.Config) error { + for name, provider := range aiConfig.Providers { + switch name { + case "azopenai": + ai.RegisterProvider(azopenai.NewProvider( + provider["api_key"], + provider["api_endpoint"], + provider["deployment_id"], + provider["api_version"], + )) + case "gemini": + ai.RegisterProvider(gemini.NewProvider(provider["api_key"])) + case "openai": + ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"])) + default: + log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name) + } + } + log.InfoStatusEvent(os.Stdout, "registered [%s] AI provider", name) + return nil } func init() { diff --git a/example/10-ai/sfn-timezone-calculator/main.go b/example/10-ai/sfn-timezone-calculator/main.go index da41335b6..204fce0c6 100644 --- a/example/10-ai/sfn-timezone-calculator/main.go +++ b/example/10-ai/sfn-timezone-calculator/main.go @@ -18,7 +18,7 @@ type Parameter struct { } func Description() string { - return "Extract the source time and timezone information to `timeString` and `sourceTimezone`, extract the target timezone information to `targetTimezone`. the desired `timeString` format is 'YYYY-MM-DD HH:MM:SS'. the `sourceTimezone` and `targetTimezone` are in IANA Time Zone Database identifier format. The function will convert the time from the source timezone to the target timezone and return the converted time as a string in the format 'YYYY-MM-DD HH:MM:SS'. If you are not sure about the date value of `timeString`, set date value to '1900-01-01'" + return "Extract the source time and timezone information to `timeString` and `sourceTimezone`, extract the target timezone information to `targetTimezone`. the desired `timeString` format is `YYYY-MM-DD HH:MM:SS`. the `sourceTimezone` and `targetTimezone` are in IANA Time Zone Database identifier format. The function will convert the time from the source timezone to the target timezone and return the converted time as a string in the format `YYYY-MM-DD HH:MM:SS`. If you are not sure about the date value of `timeString`, set date value to `1900-01-01`" } func InputSchema() any { @@ -78,6 +78,7 @@ func handler(ctx serverless.Context) { targetTime, err := ConvertTimezone(msg.TimeString, msg.SourceTimezone, msg.TargetTimezone) if err != nil { slog.Error("[sfn] ConvertTimezone error", "err", err) + fcCtx.WriteErrors(err) return } diff --git a/pkg/bridge/ai/provider/gemini/model_converter.go b/pkg/bridge/ai/provider/gemini/model_converter.go new file mode 100644 index 000000000..f16a6fbec --- /dev/null +++ b/pkg/bridge/ai/provider/gemini/model_converter.go @@ -0,0 +1,131 @@ +package gemini + +import ( + "encoding/json" + + "github.com/yomorun/yomo/ai" + "github.com/yomorun/yomo/core/ylog" +) + +func convertStandardToFunctionDeclaration(functionDefinition *ai.FunctionDefinition) *FunctionDeclaration { + if functionDefinition == nil { + return nil + } + + return &FunctionDeclaration{ + Name: functionDefinition.Name, + Description: functionDefinition.Description, + Parameters: convertStandardToFunctionParameters(functionDefinition.Parameters), + } +} + +func convertFunctionDeclarationToStandard(functionDefinition *FunctionDeclaration) *ai.FunctionDefinition { + if functionDefinition == nil { + return nil + } + + return &ai.FunctionDefinition{ + Name: functionDefinition.Name, + Description: functionDefinition.Description, + Parameters: convertFunctionParametersToStandard(functionDefinition.Parameters), + } +} + +func convertStandardToFunctionParameters(parameters *ai.FunctionParameters) *FunctionParameters { + if parameters == nil { + return nil + } + + return &FunctionParameters{ + Type: parameters.Type, + Properties: convertStandardToProperty(parameters.Properties), + Required: parameters.Required, + } +} + +func convertFunctionParametersToStandard(parameters *FunctionParameters) *ai.FunctionParameters { + if parameters == nil { + return nil + } + + return &ai.FunctionParameters{ + Type: parameters.Type, + Properties: convertPropertyToStandard(parameters.Properties), + Required: parameters.Required, + } +} + +func convertStandardToProperty(properties map[string]*ai.ParameterProperty) map[string]*Property { + if properties == nil { + return nil + } + + result := make(map[string]*Property) + for k, v := range properties { + result[k] = &Property{ + Type: v.Type, + Description: v.Description, + } + } + return result +} + +func convertPropertyToStandard(properties map[string]*Property) map[string]*ai.ParameterProperty { + if properties == nil { + return nil + } + + result := make(map[string]*ai.ParameterProperty) + for k, v := range properties { + result[k] = &ai.ParameterProperty{ + Type: v.Type, + Description: v.Description, + } + } + return result +} + +// generateJSONSchemaArguments generates the JSON schema arguments from OpenAPI compatible arguments +// https://ai.google.dev/docs/function_calling#how_it_works +func generateJSONSchemaArguments(args map[string]interface{}) string { + schema := make(map[string]interface{}) + + for k, v := range args { + schema[k] = v + } + + schemaJSON, err := json.Marshal(schema) + if err != nil { + return "" + } + + return string(schemaJSON) +} + +func parseAPIResponseBody(respBody []byte) (*Response, error) { + var response *Response + err := json.Unmarshal(respBody, &response) + if err != nil { + ylog.Error("parseAPIResponseBody", "err", err, "respBody", string(respBody)) + return nil, err + } + return response, nil +} + +func parseToolCallFromResponse(response *Response) []ai.ToolCall { + calls := make([]ai.ToolCall, 0) + for _, candidate := range response.Candidates { + fn := candidate.Content.Parts[0].FunctionCall + fd := &ai.FunctionDefinition{ + Name: fn.Name, + Arguments: generateJSONSchemaArguments(fn.Args), + } + call := ai.ToolCall{ + ID: "cc-gemini-id", + Type: "cc-function", + Function: fd, + } + calls = append(calls, call) + } + return calls +} diff --git a/pkg/bridge/ai/provider/gemini/model_request.go b/pkg/bridge/ai/provider/gemini/model_request.go new file mode 100644 index 000000000..4340b6569 --- /dev/null +++ b/pkg/bridge/ai/provider/gemini/model_request.go @@ -0,0 +1,43 @@ +package gemini + +// RequestBody is the request body +type RequestBody struct { + Contents Contents `json:"contents"` + Tools []Tool `json:"tools"` +} + +// Contents is the contents in RequestBody +type Contents struct { + Role string `json:"role"` + Parts Parts `json:"parts"` +} + +// Parts is the contents.parts in RequestBody +type Parts struct { + Text string `json:"text"` +} + +// Tool is the element of tools in RequestBody +type Tool struct { + FunctionDeclarations []*FunctionDeclaration `json:"function_declarations"` +} + +// FunctionDeclaration is the element of Tool +type FunctionDeclaration struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters *FunctionParameters `json:"parameters"` +} + +// FunctionParameters is the parameters of FunctionDeclaration +type FunctionParameters struct { + Type string `json:"type"` + Properties map[string]*Property `json:"properties"` + Required []string `json:"required"` +} + +// Property is the element of ParameterProperties +type Property struct { + Type string `json:"type"` + Description string `json:"description"` +} diff --git a/pkg/bridge/ai/provider/gemini/model_response.go b/pkg/bridge/ai/provider/gemini/model_response.go new file mode 100644 index 000000000..073c84bab --- /dev/null +++ b/pkg/bridge/ai/provider/gemini/model_response.go @@ -0,0 +1,55 @@ +package gemini + +type Response struct { + Candidates []Candidate `json:"candidates"` + PromptFeedback PromptFeedback `json:"promptFeedback"` + // UsageMetadata UsageMetadata `json:"usageMetadata"` +} + +// Candidate is the element of Response +type Candidate struct { + Content *CandidateContent `json:"content"` + FinishReason string `json:"finishReason"` + Index int `json:"index"` + // SafetyRatings []CandidateSafetyRating `json:"safetyRatings"` +} + +// CandidateContent is the content of Candidate +type CandidateContent struct { + Parts []*Part `json:"parts"` + Role string `json:"role"` +} + +// Part is the element of CandidateContent +type Part struct { + FunctionCall *FunctionCall `json:"functionCall"` +} + +// FunctionCall is the functionCall of Part +type FunctionCall struct { + Name string `json:"name"` + Args map[string]interface{} `json:"args"` +} + +// CandidateSafetyRating is the safetyRatings of Candidate +type CandidateSafetyRating struct { + Category string `json:"category"` + Probability string `json:"probability"` +} + +// UsageMetadata is the token usage in Response +type UsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` +} + +// SafetyRating is the element of PromptFeedback +type SafetyRating struct { + Category string `json:"category"` + Probability string `json:"probability"` +} + +// PromptFeedback is the feedback of Prompt +type PromptFeedback struct { + SafetyRatings []*SafetyRating `json:"safetyRatings"` +} diff --git a/pkg/bridge/ai/provider/gemini/model_test.go b/pkg/bridge/ai/provider/gemini/model_test.go new file mode 100644 index 000000000..1fbe45343 --- /dev/null +++ b/pkg/bridge/ai/provider/gemini/model_test.go @@ -0,0 +1,362 @@ +package gemini + +import ( + "reflect" + "testing" + + "github.com/yomorun/yomo/ai" +) + +func TestConvertPropertyToStandard(t *testing.T) { + properties := map[string]*Property{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + } + + expected := map[string]*ai.ParameterProperty{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + } + + result := convertPropertyToStandard(properties) + + if !reflect.DeepEqual(result, expected) { + t.Errorf("convertPropertyToStandard() = %v, want %v", result, expected) + } +} + +func TestConvertPropertyToStandard_NilInput(t *testing.T) { + result := convertPropertyToStandard(nil) + + if result != nil { + t.Errorf("convertPropertyToStandard() = %v, want %v", result, nil) + } +} + +func TestConvertFunctionParametersToStandard(t *testing.T) { + parameters := &FunctionParameters{ + Type: "type1", + Properties: map[string]*Property{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + } + + expected := &ai.FunctionParameters{ + Type: "type1", + Properties: map[string]*ai.ParameterProperty{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + } + + result := convertFunctionParametersToStandard(parameters) + + if !reflect.DeepEqual(result, expected) { + t.Errorf("convertFunctionParametersToStandard() = %v, want %v", result, expected) + } +} + +func TestConvertFunctionParametersToStandard_NilInput(t *testing.T) { + result := convertFunctionParametersToStandard(nil) + + if result != nil { + t.Errorf("convertFunctionParametersToStandard() = %v, want %v", result, nil) + } +} + +func TestConvertFunctionDeclarationToStandard(t *testing.T) { + functionDeclaration := &FunctionDeclaration{ + Name: "function1", + Description: "desc1", + Parameters: &FunctionParameters{ + Type: "type1", + Properties: map[string]*Property{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + }, + } + + expected := &ai.FunctionDefinition{ + Name: "function1", + Description: "desc1", + Parameters: &ai.FunctionParameters{ + Type: "type1", + Properties: map[string]*ai.ParameterProperty{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + }, + } + + result := convertFunctionDeclarationToStandard(functionDeclaration) + + if !reflect.DeepEqual(result, expected) { + t.Errorf("convertFunctionDeclarationToStandard() = %v, want %v", result, expected) + } +} + +func TestConvertFunctionDeclarationToStandard_NilInput(t *testing.T) { + result := convertFunctionDeclarationToStandard(nil) + + if result != nil { + t.Errorf("convertFunctionDeclarationToStandard() = %v, want %v", result, nil) + } +} + +func TestConvertStandardToProperty(t *testing.T) { + properties := map[string]*ai.ParameterProperty{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + } + + expected := map[string]*Property{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + } + + result := convertStandardToProperty(properties) + + if !reflect.DeepEqual(result, expected) { + t.Errorf("convertStandardToProperty() = %v, want %v", result, expected) + } +} + +func TestConvertStandardToProperty_NilInput(t *testing.T) { + result := convertStandardToProperty(nil) + + if result != nil { + t.Errorf("convertStandardToProperty() = %v, want %v", result, nil) + } +} + +func TestConvertStandardToFunctionParameters(t *testing.T) { + parameters := &ai.FunctionParameters{ + Type: "type1", + Properties: map[string]*ai.ParameterProperty{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + } + + expected := &FunctionParameters{ + Type: "type1", + Properties: map[string]*Property{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + } + + result := convertStandardToFunctionParameters(parameters) + + if !reflect.DeepEqual(result, expected) { + t.Errorf("convertStandardToFunctionParameters() = %v, want %v", result, expected) + } +} + +func TestConvertStandardToFunctionParameters_NilInput(t *testing.T) { + result := convertStandardToFunctionParameters(nil) + + if result != nil { + t.Errorf("convertStandardToFunctionParameters() = %v, want %v", result, nil) + } +} + +func TestConvertStandardToFunctionDeclaration(t *testing.T) { + functionDefinition := &ai.FunctionDefinition{ + Name: "function1", + Description: "desc1", + Parameters: &ai.FunctionParameters{ + Type: "type1", + Properties: map[string]*ai.ParameterProperty{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + }, + } + + expected := &FunctionDeclaration{ + Name: "function1", + Description: "desc1", + Parameters: &FunctionParameters{ + Type: "type1", + Properties: map[string]*Property{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + }, + } + + result := convertStandardToFunctionDeclaration(functionDefinition) + + if !reflect.DeepEqual(result, expected) { + t.Errorf("convertStandardToFunctionDeclaration() = %v, want %v", result, expected) + } +} + +func TestConvertStandardToFunctionDeclaration_NilInput(t *testing.T) { + result := convertStandardToFunctionDeclaration(nil) + + if result != nil { + t.Errorf("convertStandardToFunctionDeclaration() = %v, want %v", result, nil) + } +} + +func TestGenerateJSONSchemaArguments(t *testing.T) { + args := map[string]interface{}{ + "arg1": "value1", + "arg2": "value2", + } + + expected := `{"arg1":"value1","arg2":"value2"}` + + result := generateJSONSchemaArguments(args) + + if result != expected { + t.Errorf("generateJSONSchemaArguments() = %v, want %v", result, expected) + } +} + +func TestGenerateJSONSchemaArguments_EmptyArgs(t *testing.T) { + args := map[string]interface{}{} + + expected := `{}` + + result := generateJSONSchemaArguments(args) + + if result != expected { + t.Errorf("generateJSONSchemaArguments() = %v, want %v", result, expected) + } +} + +func TestParseAPIResponseBody(t *testing.T) { + respBody := []byte(`{"candidates":[{"content":{"parts":[{"functionCall":{"name":"converter","args":{"timeString":"1900-01-01 07:00:00","targetTimezone":"Asia/Singapore","sourceTimezone":"America/Los_Angeles"}}}],"role":"model"},"finishReason":"STOP","index":0}],"promptFeedback":{"safetyRatings":[{"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","probability":"NEGLIGIBLE"},{"category":"HARM_CATEGORY_HATE_SPEECH","probability":"NEGLIGIBLE"},{"category":"HARM_CATEGORY_HARASSMENT","probability":"NEGLIGIBLE"},{"category":"HARM_CATEGORY_DANGEROUS_CONTENT","probability":"NEGLIGIBLE"}]}}`) + expected := &Response{ + Candidates: []Candidate{ + { + Content: &CandidateContent{ + Parts: []*Part{ + { + FunctionCall: &FunctionCall{ + Name: "converter", + Args: map[string]interface{}{ + "timeString": "1900-01-01 07:00:00", + "targetTimezone": "Asia/Singapore", + "sourceTimezone": "America/Los_Angeles", + }, + }, + }, + }, + Role: "model", + }, + FinishReason: "STOP", + Index: 0, + }, + }, + } + + result, err := parseAPIResponseBody(respBody) + if err != nil { + t.Fatalf("parseAPIResponseBody() error = %v, wantErr %v", err, false) + } + + if !reflect.DeepEqual(result.Candidates, expected.Candidates) { + t.Errorf("parseAPIResponseBody() = %v, want %v", result, expected) + } +} + +func TestParseAPIResponseBody_InvalidJSON(t *testing.T) { + respBody := []byte(`invalid json`) + + _, err := parseAPIResponseBody(respBody) + if err == nil { + t.Errorf("parseAPIResponseBody() error = %v, wantErr %v", err, true) + } +} + +func TestParseAPIResponseBody_JSON(t *testing.T) { + str := "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"functionCall\": {\n \"name\": \"converter\",\n \"args\": {\n \"timeString\": \"1900-01-01 07:00:00\",\n \"targetTimezone\": \"Asia/Singapore\",\n \"sourceTimezone\": \"America/Los_Angeles\"\n }\n }\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0\n }\n ],\n \"promptFeedback\": {\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n}\n" + + respBody := []byte(str) + + expected := &Response{ + Candidates: []Candidate{ + { + Content: &CandidateContent{ + Parts: []*Part{ + { + FunctionCall: &FunctionCall{ + Name: "converter", + Args: map[string]interface{}{ + "timeString": "1900-01-01 07:00:00", + "targetTimezone": "Asia/Singapore", + "sourceTimezone": "America/Los_Angeles", + }, + }, + }, + }, + Role: "model", + }, + FinishReason: "STOP", + Index: 0, + }, + }, + } + + result, err := parseAPIResponseBody(respBody) + if err != nil { + t.Fatalf("parseAPIResponseBody() error = %v, wantErr %v", err, false) + } + + if !reflect.DeepEqual(result.Candidates, expected.Candidates) { + t.Errorf("parseAPIResponseBody() = %v, want %v", result, expected) + } +} + +func TestParseToolCallFromResponse(t *testing.T) { + resp := &Response{ + Candidates: []Candidate{ + { + Content: &CandidateContent{ + Parts: []*Part{ + { + FunctionCall: &FunctionCall{ + Name: "find_theaters", + Args: map[string]interface{}{ + "location": "Mountain View, CA", + "movie": "Barbie", + }, + }, + }, + }, + }, + }, + }, + } + + expected := []ai.ToolCall{ + { + Function: &ai.FunctionDefinition{ + Name: "find_theaters", + Arguments: "{\"location\":\"Mountain View, CA\",\"movie\":\"Barbie\"}", + }, + ID: "cc-gemini-id", + Type: "cc-function", + }, + } + + result := parseToolCallFromResponse(resp) + + if !reflect.DeepEqual(result, expected) { + t.Errorf("parseToolCallFromResponse() = %v, want %v", result, expected) + } +} diff --git a/pkg/bridge/ai/provider/gemini/provider.go b/pkg/bridge/ai/provider/gemini/provider.go new file mode 100644 index 000000000..8aeb18f8c --- /dev/null +++ b/pkg/bridge/ai/provider/gemini/provider.go @@ -0,0 +1,239 @@ +package gemini + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "sync" + + "github.com/yomorun/yomo/ai" + "github.com/yomorun/yomo/core/ylog" + baseProvider "github.com/yomorun/yomo/pkg/bridge/ai" +) + +var fns sync.Map + +type connectedFn struct { + connID uint64 + tag uint32 + fd *FunctionDeclaration +} + +func init() { + fns = sync.Map{} +} + +// GeminiProvider is the provider for Gemini +type GeminiProvider struct { + APIKey string +} + +var _ = baseProvider.LLMProvider(&GeminiProvider{}) + +// Name returns the name of the provider +func (p *GeminiProvider) Name() string { + return "gemini" +} + +// GetChatCompletions get chat completions for ai service +func (p *GeminiProvider) GetChatCompletions(userInstruction string) (*ai.InvokeResponse, error) { + // check if there are any tool calls attached, if no, return error + isEmpty := true + fns.Range(func(_, _ interface{}) bool { + isEmpty = false + return false + }) + + if isEmpty { + ylog.Error("-----tools is empty") + return &ai.InvokeResponse{Content: "no toolCalls"}, ai.ErrNoFunctionCall + } + + // prepare request body + body := p.prepareRequestBody(userInstruction) + + // request API + jsonBody, err := json.Marshal(body) + if err != nil { + ylog.Error(err.Error()) + return nil, err + } + + ylog.Debug("gemini api request", "body", string(jsonBody)) + + req, err := http.NewRequest("POST", p.getApiUrl(), bytes.NewBuffer(jsonBody)) + if err != nil { + ylog.Error(err.Error()) + fmt.Println("Error creating new request:", err) + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + ylog.Error(err.Error()) + fmt.Println("Error making request:", err) + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + ylog.Error(err.Error()) + fmt.Println("Error reading response body:", err) + return nil, err + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("gemini provider api response status code is %d", resp.StatusCode) + } + + ylog.Debug("gemini api response", "body", string(respBody)) + + // parse response + response, err := parseAPIResponseBody(respBody) + if err != nil { + ylog.Error(err.Error()) + return nil, err + } + + // get all candidates as []*ai.ToolCall + calls := parseToolCallFromResponse(response) + + ylog.Debug("gemini api response", "calls", len(calls)) + + result := &ai.InvokeResponse{} + if len(calls) == 0 { + return result, ai.ErrNoFunctionCall + } + + result.ToolCalls = make(map[uint32][]*ai.ToolCall) + for _, call := range calls { + ylog.Debug("++call", "call", call.Function.Name, "call", call.Function.Arguments) + fns.Range(func(_, value interface{}) bool { + fn := value.(*connectedFn) + ylog.Debug("-->", "call.Function.Name", call.Function.Name, "fns.fd.Name", fn.fd.Name) + if call.Function.Name == fn.fd.Name { + ylog.Debug("-----> add function", "name", fn.fd.Name, "tag", fn.tag) + currentCall := call + result.ToolCalls[fn.tag] = append(result.ToolCalls[fn.tag], ¤tCall) + } + return true + }) + } + + // messages + return result, nil +} + +// RegisterFunction registers the llm function +func (p *GeminiProvider) RegisterFunction(tag uint32, functionDefinition *ai.FunctionDefinition, connID uint64) error { + // replace "-" in functionDefinition.Name to "_" as gemini does not support "-" + functionDefinition.Name = strings.Replace(functionDefinition.Name, "-", "_", -1) + + fns.Store(connID, &connectedFn{ + connID: connID, + tag: tag, + fd: convertStandardToFunctionDeclaration(functionDefinition), + }) + + return nil +} + +// UnregisterFunction unregister the llm function +func (p *GeminiProvider) UnregisterFunction(name string, connID uint64) error { + fns.Delete(connID) + return nil +} + +// ListToolCalls lists the llm tool calls +func (p *GeminiProvider) ListToolCalls() (map[uint32]ai.ToolCall, error) { + result := make(map[uint32]ai.ToolCall) + + tmp := make(map[uint32]*FunctionDeclaration) + fns.Range(func(_, value any) bool { + fn := value.(*connectedFn) + tmp[fn.tag] = fn.fd + result[fn.tag] = ai.ToolCall{ + Function: convertFunctionDeclarationToStandard(fn.fd), + } + return true + }) + + return result, nil +} + +// GetOverview returns the overview of the AI functions, key is the tag, value is the function definition +func (p *GeminiProvider) GetOverview() (*ai.OverviewResponse, error) { + isEmpty := true + fns.Range(func(_, _ any) bool { + isEmpty = false + return false + }) + + result := &ai.OverviewResponse{ + Functions: make(map[uint32]*ai.FunctionDefinition), + } + + if isEmpty { + return result, nil + } + + fns.Range(func(_, value any) bool { + fn := value.(*connectedFn) + result.Functions[fn.tag] = convertFunctionDeclarationToStandard(fn.fd) + return true + }) + + return result, nil +} + +// getApiUrl returns the gemini generateContent API url +func (p *GeminiProvider) getApiUrl() string { + return fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=%s", p.APIKey) +} + +// prepareRequestBody prepares the request body for the API +func (p *GeminiProvider) prepareRequestBody(userInstruction string) *RequestBody { + body := &RequestBody{} + + // prepare contents + body.Contents.Role = "user" + body.Contents.Parts.Text = userInstruction + + // prepare tools + toolCalls := make([]*FunctionDeclaration, 0) + fns.Range(func(_, value interface{}) bool { + fn := value.(*connectedFn) + toolCalls = append(toolCalls, fn.fd) + return true + }) + body.Tools = make([]Tool, 0) + if len(toolCalls) > 0 { + body.Tools = append(body.Tools, Tool{ + FunctionDeclarations: toolCalls, + }) + } + + return body +} + +// NewProvider creates a new GeminiProvider +func NewProvider(apiKey string) *GeminiProvider { + if apiKey == "" { + apiKey = os.Getenv("GEMINI_API_KEY") + } + p := &GeminiProvider{ + APIKey: apiKey, + } + apiURL := p.getApiUrl() + ylog.Debug("new gemini provider", "api_endpoint", apiURL) + + return p +} diff --git a/pkg/bridge/ai/provider/gemini/provider_test.go b/pkg/bridge/ai/provider/gemini/provider_test.go new file mode 100644 index 000000000..090e0ca54 --- /dev/null +++ b/pkg/bridge/ai/provider/gemini/provider_test.go @@ -0,0 +1,539 @@ +package gemini + +import ( + "encoding/json" + "errors" + "os" + "reflect" + "sync" + "testing" + + "github.com/stretchr/testify/require" + "github.com/yomorun/yomo/ai" +) + +func TestGeminiProvider_Name(t *testing.T) { + provider := &GeminiProvider{} + + name := provider.Name() + + if name != "gemini" { + t.Errorf("Name() = %v, want %v", name, "gemini") + } +} + +func TestGeminiProvider_getApiUrl(t *testing.T) { + provider := &GeminiProvider{ + APIKey: "test-api-key", + } + + expected := "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=test-api-key" + + result := provider.getApiUrl() + + if result != expected { + t.Errorf("getApiUrl() = %v, want %v", result, expected) + } +} + +func TestNewProvider(t *testing.T) { + apiKey := "test-api-key" + provider := NewProvider(apiKey) + + if provider.APIKey != apiKey { + t.Errorf("NewProvider() = %v, want %v", provider.APIKey, apiKey) + } +} + +func TestNewProviderWithEnvVar(t *testing.T) { + // Set up + expectedAPIKey := "test-api-key" + os.Setenv("GEMINI_API_KEY", expectedAPIKey) + + // Call the function under test + provider := NewProvider("") + + // Check the result + if provider.APIKey != expectedAPIKey { + t.Errorf("NewProvider() = %v, want %v", provider.APIKey, expectedAPIKey) + } +} + +func TestNewProviderWithoutEnvVar(t *testing.T) { + // Set up + os.Unsetenv("GEMINI_API_KEY") + + // Call the function under test + provider := NewProvider("") + + // Check the result + if provider.APIKey != "" { + t.Errorf("NewProvider() = %v, want %v", provider.APIKey, "") + } +} + +func TestGeminiProvider_GetOverview_Empty(t *testing.T) { + provider := &GeminiProvider{} + + result, err := provider.GetOverview() + if err != nil { + t.Errorf("GetOverview() error = %v, wantErr %v", err, nil) + return + } + + if len(result.Functions) != 0 { + t.Errorf("GetOverview() = %v, want %v", len(result.Functions), 0) + } +} + +func TestGeminiProvider_GetOverview_NotEmpty(t *testing.T) { + provider := &GeminiProvider{} + + // Add a function to the fns map + fns.Store("test", &connectedFn{ + tag: 1, + fd: &FunctionDeclaration{ + Name: "function1", + Description: "desc1", + Parameters: &FunctionParameters{ + Type: "type1", + Properties: map[string]*Property{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + }, + }, + }) + + result, err := provider.GetOverview() + if err != nil { + t.Errorf("GetOverview() error = %v, wantErr %v", err, nil) + return + } + + if len(result.Functions) != 1 { + t.Errorf("GetOverview() = %v, want %v", len(result.Functions), 1) + } +} + +func TestGeminiProvider_ListToolCalls_Empty(t *testing.T) { + fns = sync.Map{} + provider := &GeminiProvider{} + + result, err := provider.ListToolCalls() + if err != nil { + t.Errorf("ListToolCalls() error = %v, wantErr %v", err, nil) + return + } + + if len(result) != 0 { + t.Errorf("ListToolCalls() = %v, want %v", len(result), 0) + } +} + +func TestGeminiProvider_ListToolCalls_NotEmpty(t *testing.T) { + provider := &GeminiProvider{} + + // Add a function to the fns map + fns.Store("test", &connectedFn{ + tag: 1, + fd: &FunctionDeclaration{ + Name: "function1", + Description: "desc1", + Parameters: &FunctionParameters{ + Type: "type1", + Properties: map[string]*Property{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + }, + }, + }) + + result, err := provider.ListToolCalls() + if err != nil { + t.Errorf("ListToolCalls() error = %v, wantErr %v", err, nil) + return + } + + if len(result) != 1 { + t.Errorf("ListToolCalls() = %v, want %v", len(result), 1) + } + + // TearDown + fns = sync.Map{} +} + +func TestGeminiProvider_RegisterFunction(t *testing.T) { + provider := &GeminiProvider{} + tag := uint32(1) + connID := uint64(1) + functionDefinition := &ai.FunctionDefinition{ + Name: "function1", + Description: "desc1", + Parameters: &ai.FunctionParameters{ + Type: "type1", + Properties: map[string]*ai.ParameterProperty{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + }, + } + + err := provider.RegisterFunction(tag, functionDefinition, connID) + if err != nil { + t.Errorf("RegisterFunction() error = %v, wantErr %v", err, nil) + return + } + + value, ok := fns.Load(connID) + if !ok { + t.Errorf("RegisterFunction() did not store the function correctly") + return + } + + cf := value.(*connectedFn) + if cf.connID != connID || cf.tag != tag || !reflect.DeepEqual(cf.fd, convertStandardToFunctionDeclaration(functionDefinition)) { + t.Errorf("RegisterFunction() = %v, want %v", cf, &connectedFn{ + connID: connID, + tag: tag, + fd: convertStandardToFunctionDeclaration(functionDefinition), + }) + } +} + +func TestGeminiProvider_UnregisterFunction(t *testing.T) { + provider := &GeminiProvider{} + connID := uint64(1) + + // Add a function to the fns map + fns.Store(connID, &connectedFn{ + tag: 1, + fd: &FunctionDeclaration{ + Name: "function1", + Description: "desc1", + Parameters: &FunctionParameters{ + Type: "type1", + Properties: map[string]*Property{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + }, + }, + }) + + err := provider.UnregisterFunction("function1", connID) + if err != nil { + t.Errorf("UnregisterFunction() error = %v, wantErr %v", err, nil) + return + } + + _, ok := fns.Load(connID) + if ok { + t.Errorf("UnregisterFunction() did not remove the function correctly") + } + + // TearDown + fns = sync.Map{} +} + +func TestGeminiProvider_GetChatCompletions_NoFunctions(t *testing.T) { + fns = sync.Map{} + + provider := &GeminiProvider{} + + result, err := provider.GetChatCompletions("test") + + if !errors.Is(err, ai.ErrNoFunctionCall) { + t.Errorf("GetChatCompletions() error = %v, wantErr %v", err, ai.ErrNoFunctionCall) + return + } + + if result.Content != "no toolCalls" { + t.Errorf("GetChatCompletions() = %v, want %v", result.Content, "no toolCalls") + } +} + +func TestGeminiProvider_prepareRequestBody_NilInstruction(t *testing.T) { + provider := &GeminiProvider{} + + userInstruction := "" + expected := &RequestBody{ + Contents: Contents{ + Role: "user", + Parts: Parts{ + Text: userInstruction, + }, + }, + Tools: []Tool{}, + } + + result := provider.prepareRequestBody(userInstruction) + + if !reflect.DeepEqual(result, expected) { + t.Errorf("prepareRequestBody() = %v, want %v", result, expected) + } +} + +func TestGeminiProvider_prepareRequestBody_Cleanup(t *testing.T) { + t.Log("-------------tear down------------") + // TearDown + fns = sync.Map{} +} + +func TestGeminiProvider_prepareRequestBody_NoFunctions(t *testing.T) { + provider := &GeminiProvider{} + + userInstruction := "test instruction" + expected := &RequestBody{ + Contents: Contents{ + Role: "user", + Parts: Parts{ + Text: userInstruction, + }, + }, + Tools: []Tool{}, + } + + result := provider.prepareRequestBody(userInstruction) + + if !reflect.DeepEqual(result, expected) { + t.Errorf("prepareRequestBody() = %v, want %v", result, expected) + } +} + +func TestGeminiProvider_prepareRequestBody_EmptyInstruction(t *testing.T) { + provider := &GeminiProvider{} + + userInstruction := "" + expected := &RequestBody{ + Contents: Contents{ + Role: "user", + Parts: Parts{ + Text: userInstruction, + }, + }, + Tools: []Tool{}, + } + + result := provider.prepareRequestBody(userInstruction) + + if !reflect.DeepEqual(result, expected) { + t.Errorf("prepareRequestBody() = %v, want %v", result, expected) + } +} + +func TestGeminiProvider_prepareRequestBody(t *testing.T) { + provider := &GeminiProvider{} + + // Add a function to the fns map + fns.Store(uint64(1), &connectedFn{ + tag: 1, + fd: &FunctionDeclaration{ + Name: "function1", + Description: "desc1", + Parameters: &FunctionParameters{ + Type: "type1", + Properties: map[string]*Property{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + }, + }, + }) + + userInstruction := "test instruction" + expected := &RequestBody{ + Contents: Contents{ + Role: "user", + Parts: Parts{ + Text: userInstruction, + }, + }, + Tools: []Tool{ + { + FunctionDeclarations: []*FunctionDeclaration{ + { + Name: "function1", + Description: "desc1", + Parameters: &FunctionParameters{ + Type: "type1", + Properties: map[string]*Property{ + "prop1": {Type: "type1", Description: "desc1"}, + "prop2": {Type: "type2", Description: "desc2"}, + }, + Required: []string{"prop1"}, + }, + }, + }, + }, + }, + } + + result := provider.prepareRequestBody(userInstruction) + + if !reflect.DeepEqual(result, expected) { + t.Errorf("prepareRequestBody() = %v, want %v", result, expected) + } +} + +func TestGeminiProvider_prepareRequestBody_JSON(t *testing.T) { + provider := &GeminiProvider{} + + // Add a function to the fns map + fns.Store(uint64(1), &connectedFn{ + tag: 1, + fd: &FunctionDeclaration{ + Name: "find_movies", + Description: "find movie titles currently playing in theaters based on any description, genre, title words, etc.", + Parameters: &FunctionParameters{ + Type: "object", + Properties: map[string]*Property{ + "location": {Type: "string", Description: "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"}, + "description": {Type: "string", Description: "Any kind of description including category or genre, title words, attributes, etc."}, + }, + Required: []string{"description"}, + }, + }, + }) + + fns.Store(uint64(2), &connectedFn{ + tag: 2, + fd: &FunctionDeclaration{ + Name: "find_theaters", + Description: "find theaters based on location and optionally movie title which are is currently playing in theaters", + Parameters: &FunctionParameters{ + Type: "object", + Properties: map[string]*Property{ + "location": {Type: "string", Description: "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"}, + "movie": {Type: "string", Description: "Any movie title"}, + }, + Required: []string{"location"}, + }, + }, + }) + + fns.Store(uint64(3), &connectedFn{ + tag: 3, + fd: &FunctionDeclaration{ + Name: "get_showtimes", + Description: "Find the start times for movies playing in a specific theater", + Parameters: &FunctionParameters{ + Type: "object", + Properties: map[string]*Property{ + "location": {Type: "string", Description: "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"}, + "movie": {Type: "string", Description: "Any movie title"}, + "theater": {Type: "string", Description: "Name of the theater"}, + "date": {Type: "string", Description: "Date for requested showtime"}, + }, + Required: []string{"location", "movie", "theater", "date"}, + }, + }, + }) + + userInstruction := "Which theaters in Mountain View show Barbie movie?" + + expected := `{ + "contents": { + "role": "user", + "parts": { + "text": "Which theaters in Mountain View show Barbie movie?" + } + }, + "tools": [ + { + "function_declarations": [ + { + "name": "find_movies", + "description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616" + }, + "description": { + "type": "string", + "description": "Any kind of description including category or genre, title words, attributes, etc." + } + }, + "required": [ + "description" + ] + } + }, + { + "name": "find_theaters", + "description": "find theaters based on location and optionally movie title which are is currently playing in theaters", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616" + }, + "movie": { + "type": "string", + "description": "Any movie title" + } + }, + "required": [ + "location" + ] + } + }, + { + "name": "get_showtimes", + "description": "Find the start times for movies playing in a specific theater", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616" + }, + "movie": { + "type": "string", + "description": "Any movie title" + }, + "theater": { + "type": "string", + "description": "Name of the theater" + }, + "date": { + "type": "string", + "description": "Date for requested showtime" + } + }, + "required": [ + "location", + "movie", + "theater", + "date" + ] + } + } + ] + } + ] +}` + + result := provider.prepareRequestBody(userInstruction) + + jsonBody, err := json.Marshal(result) + if err != nil { + t.Errorf("Error preparing request body: %v", err) + } + + require.JSONEqf(t, expected, string(jsonBody), "prepareRequestBody() = %v, want %v", string(jsonBody), expected) + + // if string(jsonBody) != expected { + // t.Errorf("prepareRequestBody() = %v, want %v", string(jsonBody), expected) + // } +}