diff --git a/cli/serve.go b/cli/serve.go index d9b6ada73..effe3f8ff 100644 --- a/cli/serve.go +++ b/cli/serve.go @@ -122,24 +122,22 @@ var serveCmd = &cobra.Command{ func registerAIProvider(aiConfig *ai.Config) { // register the AI provider for name, provider := range aiConfig.Providers { - // register the Azure OpenAI provider - if name == "azopenai" { - apiKey := provider["api_key"] - apiEndpoint := provider["api_endpoint"] - if apiKey == "" || apiEndpoint == "" { - // log.InfoStatusEvent(os.Stdout, "register Azure OpenAI provider used by New()") - ai.RegisterProvider(azopenai.New()) - } else { - // log.InfoStatusEvent(os.Stdout, "register Azure OpenAI provider used by NewAzureOpenAIProvider()") - ai.RegisterProvider(azopenai.NewAzureOpenAIProvider(apiKey, apiEndpoint)) - } + // register LLM provider + switch name { + case "azopenai": + ai.RegisterProvider(azopenai.NewProvider( + provider["api_key"], + provider["api_endpoint"], + provider["deployment_id"], + provider["api_version"], + )) + log.InfoStatusEvent(os.Stdout, "register [%s] AI provider", name) + // TODO: register other providers } // register the OpenAI provider if name == "openai" { ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"])) } - - // TODO: register other providers } } diff --git a/example/10-ai/zipper.yaml b/example/10-ai/zipper.yaml index dc0fd0590..ddbaef0b6 100644 --- a/example/10-ai/zipper.yaml +++ b/example/10-ai/zipper.yaml @@ -14,13 +14,14 @@ bridge: providers: azopenai: - api_key: api_endpoint: + deployment_id: + api_key: + api_version: openai: api_key: model: - huggingface: - model: - + gemini: + api_key: diff --git a/pkg/bridge/ai/provider/azopenai/provider.go b/pkg/bridge/ai/provider/azopenai/provider.go index dc929e41f..03a3cc001 100644 --- a/pkg/bridge/ai/provider/azopenai/provider.go +++ b/pkg/bridge/ai/provider/azopenai/provider.go @@ -66,8 +66,10 @@ type RespUsage struct { // AzureOpenAIProvider is the provider for Azure OpenAI type AzureOpenAIProvider struct { - APIKey string - APIEndpoint string + APIKey string + APIEndpoint string + DeploymentID string + APIVersion string } type connectedFn struct { @@ -80,19 +82,25 @@ func init() { fns = sync.Map{} } -// NewAzureOpenAIProvider creates a new AzureOpenAIProvider -func NewAzureOpenAIProvider(apiKey string, apiEndpoint string) *AzureOpenAIProvider { - return &AzureOpenAIProvider{ - APIKey: apiKey, - APIEndpoint: apiEndpoint, +// NewProvider creates a new AzureOpenAIProvider +func NewProvider(apiKey string, apiEndpoint string, deploymentID string, apiVersion string) *AzureOpenAIProvider { + if apiKey == "" { + apiKey = os.Getenv("AZURE_OPENAI_API_KEY") + } + if apiEndpoint == "" { + apiEndpoint = os.Getenv("AZURE_OPENAI_API_ENDPOINT") + } + if deploymentID == "" { + deploymentID = os.Getenv("AZURE_OPENAI_DEPLOYMENT_ID") + } + if apiVersion == "" { + apiVersion = os.Getenv("AZURE_OPENAI_API_VERSION") } -} - -// New creates a new AzureOpenAIProvider -func New() *AzureOpenAIProvider { return &AzureOpenAIProvider{ - APIKey: os.Getenv("AZURE_OPENAI_API_KEY"), - APIEndpoint: os.Getenv("AZURE_OPENAI_API_ENDPOINT"), + APIKey: apiKey, + APIEndpoint: apiEndpoint, + DeploymentID: deploymentID, + APIVersion: apiVersion, } } @@ -122,9 +130,6 @@ func (p *AzureOpenAIProvider) GetChatCompletions(userInstruction string) (*ai.In // prepare tools toolCalls := make([]ai.ToolCall, 0) - // for _, v := range tools { - // toolCalls = append(toolCalls, v) - // } fns.Range(func(_, value interface{}) bool { fn := value.(*connectedFn) toolCalls = append(toolCalls, fn.tc) @@ -140,7 +145,9 @@ func (p *AzureOpenAIProvider) GetChatCompletions(userInstruction string) (*ai.In return nil, err } - req, err := http.NewRequest("POST", p.APIEndpoint, bytes.NewBuffer(jsonBody)) + url := fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", p.APIEndpoint, p.DeploymentID, p.APIVersion) + ylog.Debug("chat completions request", "url", url) + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) if err != nil { return nil, err } diff --git a/pkg/bridge/ai/test/ai_test.go b/pkg/bridge/ai/test/ai_test.go index 9d3eb8c43..38992ffc0 100644 --- a/pkg/bridge/ai/test/ai_test.go +++ b/pkg/bridge/ai/test/ai_test.go @@ -39,7 +39,7 @@ func TestAIToolCalls(t *testing.T) { } func startAIServer() error { - ai.RegisterProvider(azopenai.New()) + ai.RegisterProvider(azopenai.NewProvider("", "", "gpt35", "2023-12-01-preview")) aiConfig := &ai.Config{ Server: ai.Server{ Addr: "localhost:6000",