Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Google Gemini API provider #2

Merged
merged 13 commits into from
Mar 6, 2024
4 changes: 4 additions & 0 deletions cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"github.com/yomorun/yomo/pkg/bridge/ai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/azopenai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/gemini"
)

// serveCmd represents the serve command
Expand Down Expand Up @@ -133,6 +134,9 @@ func registerAIProvider(aiConfig *ai.Config) {
ai.RegisterProvider(azopenai.NewAzureOpenAIProvider(apiKey, apiEndpoint))
}
}
if name == "gemini" {
ai.RegisterProvider(gemini.NewProvider(provider["api_key"]))
}
// TODO: register other providers
}
}
Expand Down
3 changes: 2 additions & 1 deletion example/10-ai/sfn-timezone-calculator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Parameter struct {
}

func Description() string {
return "Extract the source time and timezone information to `timeString` and `sourceTimezone`, extract the target timezone information to `targetTimezone`. the desired `timeString` format is 'YYYY-MM-DD HH:MM:SS'. the `sourceTimezone` and `targetTimezone` are in IANA Time Zone Database identifier format. The function will convert the time from the source timezone to the target timezone and return the converted time as a string in the format 'YYYY-MM-DD HH:MM:SS'. If you are not sure about the date value of `timeString`, set date value to '1900-01-01'"
return "Extract the source time and timezone information to `timeString` and `sourceTimezone`, extract the target timezone information to `targetTimezone`. the desired `timeString` format is `YYYY-MM-DD HH:MM:SS`. the `sourceTimezone` and `targetTimezone` are in IANA Time Zone Database identifier format. The function will convert the time from the source timezone to the target timezone and return the converted time as a string in the format `YYYY-MM-DD HH:MM:SS`. If you are not sure about the date value of `timeString`, set date value to `1900-01-01`"
}

func InputSchema() any {
Expand Down Expand Up @@ -78,6 +78,7 @@ func handler(ctx serverless.Context) {
targetTime, err := ConvertTimezone(msg.TimeString, msg.SourceTimezone, msg.TargetTimezone)
if err != nil {
slog.Error("[sfn] ConvertTimezone error", "err", err)
fcCtx.WriteErrors(err)
return
}

Expand Down
4 changes: 2 additions & 2 deletions example/10-ai/zipper.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ bridge:
api_key:
api_endpoint:

huggingface:
model:
gemini:
api_key:

131 changes: 131 additions & 0 deletions pkg/bridge/ai/provider/gemini/model_converter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package gemini

import (
"encoding/json"

"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core/ylog"
)

func convertStandardToFunctionDeclaration(functionDefinition *ai.FunctionDefinition) *FunctionDeclaration {
if functionDefinition == nil {
return nil
}

return &FunctionDeclaration{
Name: functionDefinition.Name,
Description: functionDefinition.Description,
Parameters: convertStandardToFunctionParameters(functionDefinition.Parameters),
}
}

func convertFunctionDeclarationToStandard(functionDefinition *FunctionDeclaration) *ai.FunctionDefinition {
if functionDefinition == nil {
return nil
}

return &ai.FunctionDefinition{
Name: functionDefinition.Name,
Description: functionDefinition.Description,
Parameters: convertFunctionParametersToStandard(functionDefinition.Parameters),
}
}

func convertStandardToFunctionParameters(parameters *ai.FunctionParameters) *FunctionParameters {
if parameters == nil {
return nil
}

return &FunctionParameters{
Type: parameters.Type,
Properties: convertStandardToProperty(parameters.Properties),
Required: parameters.Required,
}
}

func convertFunctionParametersToStandard(parameters *FunctionParameters) *ai.FunctionParameters {
if parameters == nil {
return nil
}

return &ai.FunctionParameters{
Type: parameters.Type,
Properties: convertPropertyToStandard(parameters.Properties),
Required: parameters.Required,
}
}

func convertStandardToProperty(properties map[string]*ai.ParameterProperty) map[string]*Property {
if properties == nil {
return nil
}

result := make(map[string]*Property)
for k, v := range properties {
result[k] = &Property{
Type: v.Type,
Description: v.Description,
}
}
return result
}

func convertPropertyToStandard(properties map[string]*Property) map[string]*ai.ParameterProperty {
if properties == nil {
return nil
}

result := make(map[string]*ai.ParameterProperty)
for k, v := range properties {
result[k] = &ai.ParameterProperty{
Type: v.Type,
Description: v.Description,
}
}
return result
}

// generateJSONSchemaArguments generates the JSON schema arguments from OpenAPI compatible arguments
// https://ai.google.dev/docs/function_calling#how_it_works
func generateJSONSchemaArguments(args map[string]interface{}) string {
schema := make(map[string]interface{})

for k, v := range args {
schema[k] = v
}

schemaJSON, err := json.Marshal(schema)
if err != nil {
return ""
}

return string(schemaJSON)
}

func parseAPIResponseBody(respBody []byte) (*Response, error) {
var response *Response
err := json.Unmarshal(respBody, &response)
if err != nil {
ylog.Error("parseAPIResponseBody", "err", err, "respBody", string(respBody))
return nil, err
}
return response, nil
}

func parseToolCallFromResponse(response *Response) []ai.ToolCall {
calls := make([]ai.ToolCall, 0)
for _, candidate := range response.Candidates {
fn := candidate.Content.Parts[0].FunctionCall
fd := &ai.FunctionDefinition{
Name: fn.Name,
Arguments: generateJSONSchemaArguments(fn.Args),
}
call := ai.ToolCall{
ID: "cc-gemini-id",
Type: "cc-function",
Function: fd,
}
calls = append(calls, call)
}
return calls
}
43 changes: 43 additions & 0 deletions pkg/bridge/ai/provider/gemini/model_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package gemini

// RequestBody is the request body
type RequestBody struct {
Contents Contents `json:"contents"`
Tools []Tool `json:"tools"`
}

// Contents is the contents in RequestBody
type Contents struct {
Role string `json:"role"`
Parts Parts `json:"parts"`
}

// Parts is the contents.parts in RequestBody
type Parts struct {
Text string `json:"text"`
}

// Tool is the element of tools in RequestBody
type Tool struct {
FunctionDeclarations []*FunctionDeclaration `json:"function_declarations"`
}

// FunctionDeclaration is the element of Tool
type FunctionDeclaration struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters *FunctionParameters `json:"parameters"`
}

// FunctionParameters is the parameters of FunctionDeclaration
type FunctionParameters struct {
Type string `json:"type"`
Properties map[string]*Property `json:"properties"`
Required []string `json:"required"`
}

// Property is the element of ParameterProperties
type Property struct {
Type string `json:"type"`
Description string `json:"description"`
}
55 changes: 55 additions & 0 deletions pkg/bridge/ai/provider/gemini/model_response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package gemini

type Response struct {
Candidates []Candidate `json:"candidates"`
PromptFeedback PromptFeedback `json:"promptFeedback"`
// UsageMetadata UsageMetadata `json:"usageMetadata"`
}

// Candidate is the element of Response
type Candidate struct {
Content *CandidateContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int `json:"index"`
// SafetyRatings []CandidateSafetyRating `json:"safetyRatings"`
}

// CandidateContent is the content of Candidate
type CandidateContent struct {
Parts []*Part `json:"parts"`
Role string `json:"role"`
}

// Part is the element of CandidateContent
type Part struct {
FunctionCall *FunctionCall `json:"functionCall"`
}

// FunctionCall is the functionCall of Part
type FunctionCall struct {
Name string `json:"name"`
Args map[string]interface{} `json:"args"`
}

// CandidateSafetyRating is the safetyRatings of Candidate
type CandidateSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}

// UsageMetadata is the token usage in Response
type UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}

// SafetyRating is the element of PromptFeedback
type SafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}

// PromptFeedback is the feedback of Prompt
type PromptFeedback struct {
SafetyRatings []*SafetyRating `json:"safetyRatings"`
}
Loading
Loading