From 8d8e24a2b7e87a77382ccd11cc15cba8769eb7ac Mon Sep 17 00:00:00 2001 From: "C.C" Date: Wed, 28 Feb 2024 02:41:21 -0800 Subject: [PATCH] refactor: azure openai provider configuration follows azure-sdk (#734) ```yaml providers: azopenai: api_endpoint: ## e.g. https://xxx.openai.azure.com deployment_id: api_key: api_version: ## e.g. 2023-12-01-preview ``` --------- Co-authored-by: venjiang Co-authored-by: wurui --- cli/serve.go | 23 +++++------ example/10-ai/zipper.yaml | 9 +++-- pkg/bridge/ai/provider/azopenai/provider.go | 45 ++++++++++++--------- pkg/bridge/ai/test/ai_test.go | 2 +- 4 files changed, 42 insertions(+), 37 deletions(-) diff --git a/cli/serve.go b/cli/serve.go index f5419122e..463ef4048 100644 --- a/cli/serve.go +++ b/cli/serve.go @@ -121,19 +121,18 @@ var serveCmd = &cobra.Command{ func registerAIProvider(aiConfig *ai.Config) { // register the AI provider for name, provider := range aiConfig.Providers { - // register the Azure OpenAI provider - if name == "azopenai" { - apiKey := provider["api_key"] - apiEndpoint := provider["api_endpoint"] - if apiKey == "" || apiEndpoint == "" { - // log.InfoStatusEvent(os.Stdout, "register Azure OpenAI provider used by New()") - ai.RegisterProvider(azopenai.New()) - } else { - // log.InfoStatusEvent(os.Stdout, "register Azure OpenAI provider used by NewAzureOpenAIProvider()") - ai.RegisterProvider(azopenai.NewAzureOpenAIProvider(apiKey, apiEndpoint)) - } + // register LLM provider + switch name { + case "azopenai": + ai.RegisterProvider(azopenai.NewProvider( + provider["api_key"], + provider["api_endpoint"], + provider["deployment_id"], + provider["api_version"], + )) + log.InfoStatusEvent(os.Stdout, "register [%s] AI provider", name) + // TODO: register other providers } - // TODO: register other providers } } diff --git a/example/10-ai/zipper.yaml b/example/10-ai/zipper.yaml index 323f1e9a0..c7a1945db 100644 --- a/example/10-ai/zipper.yaml +++ b/example/10-ai/zipper.yaml @@ -14,13 +14,14 @@ bridge: providers: azopenai: - api_key: api_endpoint: + deployment_id: + api_key: + api_version: openai: api_key: api_endpoint: - huggingface: - model: - + gemini: + api_key: diff --git a/pkg/bridge/ai/provider/azopenai/provider.go b/pkg/bridge/ai/provider/azopenai/provider.go index c14005d5f..03a3cc001 100644 --- a/pkg/bridge/ai/provider/azopenai/provider.go +++ b/pkg/bridge/ai/provider/azopenai/provider.go @@ -16,9 +16,7 @@ import ( "github.com/yomorun/yomo/core/ylog" ) -var ( - fns sync.Map -) +var fns sync.Map // ReqMessage is the message in Request type ReqMessage struct { @@ -68,8 +66,10 @@ type RespUsage struct { // AzureOpenAIProvider is the provider for Azure OpenAI type AzureOpenAIProvider struct { - APIKey string - APIEndpoint string + APIKey string + APIEndpoint string + DeploymentID string + APIVersion string } type connectedFn struct { @@ -82,19 +82,25 @@ func init() { fns = sync.Map{} } -// NewAzureOpenAIProvider creates a new AzureOpenAIProvider -func NewAzureOpenAIProvider(apiKey string, apiEndpoint string) *AzureOpenAIProvider { - return &AzureOpenAIProvider{ - APIKey: apiKey, - APIEndpoint: apiEndpoint, +// NewProvider creates a new AzureOpenAIProvider +func NewProvider(apiKey string, apiEndpoint string, deploymentID string, apiVersion string) *AzureOpenAIProvider { + if apiKey == "" { + apiKey = os.Getenv("AZURE_OPENAI_API_KEY") + } + if apiEndpoint == "" { + apiEndpoint = os.Getenv("AZURE_OPENAI_API_ENDPOINT") + } + if deploymentID == "" { + deploymentID = os.Getenv("AZURE_OPENAI_DEPLOYMENT_ID") + } + if apiVersion == "" { + apiVersion = os.Getenv("AZURE_OPENAI_API_VERSION") } -} - -// New creates a new AzureOpenAIProvider -func New() *AzureOpenAIProvider { return &AzureOpenAIProvider{ - APIKey: os.Getenv("AZURE_OPENAI_API_KEY"), - APIEndpoint: os.Getenv("AZURE_OPENAI_API_ENDPOINT"), + APIKey: apiKey, + APIEndpoint: apiEndpoint, + DeploymentID: deploymentID, + APIVersion: apiVersion, } } @@ -124,9 +130,6 @@ func (p *AzureOpenAIProvider) GetChatCompletions(userInstruction string) (*ai.In // prepare tools toolCalls := make([]ai.ToolCall, 0) - // for _, v := range tools { - // toolCalls = append(toolCalls, v) - // } fns.Range(func(_, value interface{}) bool { fn := value.(*connectedFn) toolCalls = append(toolCalls, fn.tc) @@ -142,7 +145,9 @@ func (p *AzureOpenAIProvider) GetChatCompletions(userInstruction string) (*ai.In return nil, err } - req, err := http.NewRequest("POST", p.APIEndpoint, bytes.NewBuffer(jsonBody)) + url := fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", p.APIEndpoint, p.DeploymentID, p.APIVersion) + ylog.Debug("chat completions request", "url", url) + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) if err != nil { return nil, err } diff --git a/pkg/bridge/ai/test/ai_test.go b/pkg/bridge/ai/test/ai_test.go index 9d3eb8c43..38992ffc0 100644 --- a/pkg/bridge/ai/test/ai_test.go +++ b/pkg/bridge/ai/test/ai_test.go @@ -39,7 +39,7 @@ func TestAIToolCalls(t *testing.T) { } func startAIServer() error { - ai.RegisterProvider(azopenai.New()) + ai.RegisterProvider(azopenai.NewProvider("", "", "gpt35", "2023-12-01-preview")) aiConfig := &ai.Config{ Server: ai.Server{ Addr: "localhost:6000",