Skip to content

Commit

Permalink
feat: add Google Gemini API provider (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
fanweixiao committed Mar 6, 2024
1 parent b5317dc commit eb8f7e6
Show file tree
Hide file tree
Showing 8 changed files with 1,403 additions and 21 deletions.
52 changes: 32 additions & 20 deletions cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"github.com/yomorun/yomo/pkg/bridge/ai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/azopenai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/gemini"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
)

Expand Down Expand Up @@ -112,26 +113,37 @@ var serveCmd = &cobra.Command{
},
}

func registerAIProvider(aiConfig *ai.Config) {
// register the AI provider
for name, provider := range aiConfig.Providers {
// register LLM provider
switch name {
case "azopenai":
ai.RegisterProvider(azopenai.NewProvider(
provider["api_key"],
provider["api_endpoint"],
provider["deployment_id"],
provider["api_version"],
))
log.InfoStatusEvent(os.Stdout, "register [%s] AI provider", name)
// TODO: register other providers
}
// register the OpenAI provider
if name == "openai" {
ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
}
}
func registerAIProvider(aiConfig *ai.Config) error {
for name, provider := range aiConfig.Providers {
switch name {
case "azopenai":
err := ai.RegisterProvider(azopenai.NewProvider(

Check failure on line 120 in cli/serve.go

View workflow job for this annotation

GitHub Actions / build

ai.RegisterProvider(azopenai.NewProvider(provider["api_key"], provider["api_endpoint"], provider["deployment_id"], provider["api_version"])) (no value) used as value
provider["api_key"],
provider["api_endpoint"],
provider["deployment_id"],
provider["api_version"],
))
if err != nil {
return fmt.Errorf("failed to register azopenai provider: %w", err)
}
log.InfoStatusEvent(os.Stdout, "registered [%s] AI provider", name)
case "gemini":
err := ai.RegisterProvider(gemini.NewProvider(provider["api_key"]))

Check failure on line 131 in cli/serve.go

View workflow job for this annotation

GitHub Actions / build

ai.RegisterProvider(gemini.NewProvider(provider["api_key"])) (no value) used as value
if err != nil {
return fmt.Errorf("failed to register gemini provider: %w", err)
}
log.InfoStatusEvent(os.Stdout, "registered [%s] AI provider", name)
case "openai":
err := ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))

Check failure on line 137 in cli/serve.go

View workflow job for this annotation

GitHub Actions / build

ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"])) (no value) used as value
if err != nil {
return fmt.Errorf("failed to register openai provider: %w", err)
}
log.InfoStatusEvent(os.Stdout, "registered [%s] AI provider", name)
default:
log.Warnf("unknown provider: %s", name)

Check failure on line 143 in cli/serve.go

View workflow job for this annotation

GitHub Actions / build

undefined: log.Warnf
}
}
return nil
}

func init() {
Expand Down
3 changes: 2 additions & 1 deletion example/10-ai/sfn-timezone-calculator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Parameter struct {
}

func Description() string {
return "Extract the source time and timezone information to `timeString` and `sourceTimezone`, extract the target timezone information to `targetTimezone`. the desired `timeString` format is 'YYYY-MM-DD HH:MM:SS'. the `sourceTimezone` and `targetTimezone` are in IANA Time Zone Database identifier format. The function will convert the time from the source timezone to the target timezone and return the converted time as a string in the format 'YYYY-MM-DD HH:MM:SS'. If you are not sure about the date value of `timeString`, set date value to '1900-01-01'"
return "Extract the source time and timezone information to `timeString` and `sourceTimezone`, extract the target timezone information to `targetTimezone`. the desired `timeString` format is `YYYY-MM-DD HH:MM:SS`. the `sourceTimezone` and `targetTimezone` are in IANA Time Zone Database identifier format. The function will convert the time from the source timezone to the target timezone and return the converted time as a string in the format `YYYY-MM-DD HH:MM:SS`. If you are not sure about the date value of `timeString`, set date value to `1900-01-01`"
}

func InputSchema() any {
Expand Down Expand Up @@ -78,6 +78,7 @@ func handler(ctx serverless.Context) {
targetTime, err := ConvertTimezone(msg.TimeString, msg.SourceTimezone, msg.TargetTimezone)
if err != nil {
slog.Error("[sfn] ConvertTimezone error", "err", err)
fcCtx.WriteErrors(err)
return
}

Expand Down
131 changes: 131 additions & 0 deletions pkg/bridge/ai/provider/gemini/model_converter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package gemini

import (
"encoding/json"

"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core/ylog"
)

func convertStandardToFunctionDeclaration(functionDefinition *ai.FunctionDefinition) *FunctionDeclaration {
if functionDefinition == nil {
return nil
}

return &FunctionDeclaration{
Name: functionDefinition.Name,
Description: functionDefinition.Description,
Parameters: convertStandardToFunctionParameters(functionDefinition.Parameters),
}
}

func convertFunctionDeclarationToStandard(functionDefinition *FunctionDeclaration) *ai.FunctionDefinition {
if functionDefinition == nil {
return nil
}

return &ai.FunctionDefinition{
Name: functionDefinition.Name,
Description: functionDefinition.Description,
Parameters: convertFunctionParametersToStandard(functionDefinition.Parameters),
}
}

func convertStandardToFunctionParameters(parameters *ai.FunctionParameters) *FunctionParameters {
if parameters == nil {
return nil
}

return &FunctionParameters{
Type: parameters.Type,
Properties: convertStandardToProperty(parameters.Properties),
Required: parameters.Required,
}
}

func convertFunctionParametersToStandard(parameters *FunctionParameters) *ai.FunctionParameters {
if parameters == nil {
return nil
}

return &ai.FunctionParameters{
Type: parameters.Type,
Properties: convertPropertyToStandard(parameters.Properties),
Required: parameters.Required,
}
}

func convertStandardToProperty(properties map[string]*ai.ParameterProperty) map[string]*Property {
if properties == nil {
return nil
}

result := make(map[string]*Property)
for k, v := range properties {
result[k] = &Property{
Type: v.Type,
Description: v.Description,
}
}
return result
}

func convertPropertyToStandard(properties map[string]*Property) map[string]*ai.ParameterProperty {
if properties == nil {
return nil
}

result := make(map[string]*ai.ParameterProperty)
for k, v := range properties {
result[k] = &ai.ParameterProperty{
Type: v.Type,
Description: v.Description,
}
}
return result
}

// generateJSONSchemaArguments generates the JSON schema arguments from OpenAPI compatible arguments
// https://ai.google.dev/docs/function_calling#how_it_works
func generateJSONSchemaArguments(args map[string]interface{}) string {
schema := make(map[string]interface{})

for k, v := range args {
schema[k] = v
}

schemaJSON, err := json.Marshal(schema)
if err != nil {
return ""
}

return string(schemaJSON)
}

func parseAPIResponseBody(respBody []byte) (*Response, error) {
var response *Response
err := json.Unmarshal(respBody, &response)
if err != nil {
ylog.Error("parseAPIResponseBody", "err", err, "respBody", string(respBody))
return nil, err
}
return response, nil
}

func parseToolCallFromResponse(response *Response) []ai.ToolCall {
calls := make([]ai.ToolCall, 0)
for _, candidate := range response.Candidates {
fn := candidate.Content.Parts[0].FunctionCall
fd := &ai.FunctionDefinition{
Name: fn.Name,
Arguments: generateJSONSchemaArguments(fn.Args),
}
call := ai.ToolCall{
ID: "cc-gemini-id",
Type: "cc-function",
Function: fd,
}
calls = append(calls, call)
}
return calls
}
43 changes: 43 additions & 0 deletions pkg/bridge/ai/provider/gemini/model_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package gemini

// RequestBody is the request body
type RequestBody struct {
Contents Contents `json:"contents"`
Tools []Tool `json:"tools"`
}

// Contents is the contents in RequestBody
type Contents struct {
Role string `json:"role"`
Parts Parts `json:"parts"`
}

// Parts is the contents.parts in RequestBody
type Parts struct {
Text string `json:"text"`
}

// Tool is the element of tools in RequestBody
type Tool struct {
FunctionDeclarations []*FunctionDeclaration `json:"function_declarations"`
}

// FunctionDeclaration is the element of Tool
type FunctionDeclaration struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters *FunctionParameters `json:"parameters"`
}

// FunctionParameters is the parameters of FunctionDeclaration
type FunctionParameters struct {
Type string `json:"type"`
Properties map[string]*Property `json:"properties"`
Required []string `json:"required"`
}

// Property is the element of ParameterProperties
type Property struct {
Type string `json:"type"`
Description string `json:"description"`
}
55 changes: 55 additions & 0 deletions pkg/bridge/ai/provider/gemini/model_response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package gemini

type Response struct {
Candidates []Candidate `json:"candidates"`
PromptFeedback PromptFeedback `json:"promptFeedback"`
// UsageMetadata UsageMetadata `json:"usageMetadata"`
}

// Candidate is the element of Response
type Candidate struct {
Content *CandidateContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int `json:"index"`
// SafetyRatings []CandidateSafetyRating `json:"safetyRatings"`
}

// CandidateContent is the content of Candidate
type CandidateContent struct {
Parts []*Part `json:"parts"`
Role string `json:"role"`
}

// Part is the element of CandidateContent
type Part struct {
FunctionCall *FunctionCall `json:"functionCall"`
}

// FunctionCall is the functionCall of Part
type FunctionCall struct {
Name string `json:"name"`
Args map[string]interface{} `json:"args"`
}

// CandidateSafetyRating is the safetyRatings of Candidate
type CandidateSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}

// UsageMetadata is the token usage in Response
type UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}

// SafetyRating is the element of PromptFeedback
type SafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}

// PromptFeedback is the feedback of Prompt
type PromptFeedback struct {
SafetyRatings []*SafetyRating `json:"safetyRatings"`
}
Loading

0 comments on commit eb8f7e6

Please sign in to comment.