Skip to content

Commit

Permalink
refactor: azure openai provider configuration follows azure-sdk (yomo…
Browse files Browse the repository at this point in the history
…run#734)

```yaml
providers:
      azopenai:
        api_endpoint: ## e.g. https://xxx.openai.azure.com
        deployment_id: 
        api_key: 
        api_version: ## e.g. 2023-12-01-preview
```

---------

Co-authored-by: venjiang <[email protected]>
Co-authored-by: wurui <[email protected]>
  • Loading branch information
3 people authored Feb 28, 2024
1 parent be34e02 commit 8d8e24a
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 37 deletions.
23 changes: 11 additions & 12 deletions cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,18 @@ var serveCmd = &cobra.Command{
func registerAIProvider(aiConfig *ai.Config) {
// register the AI provider
for name, provider := range aiConfig.Providers {
// register the Azure OpenAI provider
if name == "azopenai" {
apiKey := provider["api_key"]
apiEndpoint := provider["api_endpoint"]
if apiKey == "" || apiEndpoint == "" {
// log.InfoStatusEvent(os.Stdout, "register Azure OpenAI provider used by New()")
ai.RegisterProvider(azopenai.New())
} else {
// log.InfoStatusEvent(os.Stdout, "register Azure OpenAI provider used by NewAzureOpenAIProvider()")
ai.RegisterProvider(azopenai.NewAzureOpenAIProvider(apiKey, apiEndpoint))
}
// register LLM provider
switch name {
case "azopenai":
ai.RegisterProvider(azopenai.NewProvider(
provider["api_key"],
provider["api_endpoint"],
provider["deployment_id"],
provider["api_version"],
))
log.InfoStatusEvent(os.Stdout, "register [%s] AI provider", name)
// TODO: register other providers
}
// TODO: register other providers
}
}

Expand Down
9 changes: 5 additions & 4 deletions example/10-ai/zipper.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ bridge:

providers:
azopenai:
api_key:
api_endpoint:
deployment_id:
api_key:
api_version:

openai:
api_key:
api_endpoint:

huggingface:
model:

gemini:
api_key:
45 changes: 25 additions & 20 deletions pkg/bridge/ai/provider/azopenai/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ import (
"github.com/yomorun/yomo/core/ylog"
)

var (
fns sync.Map
)
var fns sync.Map

// ReqMessage is the message in Request
type ReqMessage struct {
Expand Down Expand Up @@ -68,8 +66,10 @@ type RespUsage struct {

// AzureOpenAIProvider is the provider for Azure OpenAI
type AzureOpenAIProvider struct {
APIKey string
APIEndpoint string
APIKey string
APIEndpoint string
DeploymentID string
APIVersion string
}

type connectedFn struct {
Expand All @@ -82,19 +82,25 @@ func init() {
fns = sync.Map{}
}

// NewAzureOpenAIProvider creates a new AzureOpenAIProvider
func NewAzureOpenAIProvider(apiKey string, apiEndpoint string) *AzureOpenAIProvider {
return &AzureOpenAIProvider{
APIKey: apiKey,
APIEndpoint: apiEndpoint,
// NewProvider creates a new AzureOpenAIProvider
func NewProvider(apiKey string, apiEndpoint string, deploymentID string, apiVersion string) *AzureOpenAIProvider {
if apiKey == "" {
apiKey = os.Getenv("AZURE_OPENAI_API_KEY")
}
if apiEndpoint == "" {
apiEndpoint = os.Getenv("AZURE_OPENAI_API_ENDPOINT")
}
if deploymentID == "" {
deploymentID = os.Getenv("AZURE_OPENAI_DEPLOYMENT_ID")
}
if apiVersion == "" {
apiVersion = os.Getenv("AZURE_OPENAI_API_VERSION")
}
}

// New creates a new AzureOpenAIProvider
func New() *AzureOpenAIProvider {
return &AzureOpenAIProvider{
APIKey: os.Getenv("AZURE_OPENAI_API_KEY"),
APIEndpoint: os.Getenv("AZURE_OPENAI_API_ENDPOINT"),
APIKey: apiKey,
APIEndpoint: apiEndpoint,
DeploymentID: deploymentID,
APIVersion: apiVersion,
}
}

Expand Down Expand Up @@ -124,9 +130,6 @@ func (p *AzureOpenAIProvider) GetChatCompletions(userInstruction string) (*ai.In

// prepare tools
toolCalls := make([]ai.ToolCall, 0)
// for _, v := range tools {
// toolCalls = append(toolCalls, v)
// }
fns.Range(func(_, value interface{}) bool {
fn := value.(*connectedFn)
toolCalls = append(toolCalls, fn.tc)
Expand All @@ -142,7 +145,9 @@ func (p *AzureOpenAIProvider) GetChatCompletions(userInstruction string) (*ai.In
return nil, err
}

req, err := http.NewRequest("POST", p.APIEndpoint, bytes.NewBuffer(jsonBody))
url := fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", p.APIEndpoint, p.DeploymentID, p.APIVersion)
ylog.Debug("chat completions request", "url", url)
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/bridge/ai/test/ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestAIToolCalls(t *testing.T) {
}

func startAIServer() error {
ai.RegisterProvider(azopenai.New())
ai.RegisterProvider(azopenai.NewProvider("", "", "gpt35", "2023-12-01-preview"))
aiConfig := &ai.Config{
Server: ai.Server{
Addr: "localhost:6000",
Expand Down

0 comments on commit 8d8e24a

Please sign in to comment.