Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
fanweixiao committed Feb 28, 2024
2 parents c5963d0 + 8d8e24a commit 882354c
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 35 deletions.
24 changes: 11 additions & 13 deletions cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,24 +122,22 @@ var serveCmd = &cobra.Command{
func registerAIProvider(aiConfig *ai.Config) {
// register the AI provider
for name, provider := range aiConfig.Providers {
// register the Azure OpenAI provider
if name == "azopenai" {
apiKey := provider["api_key"]
apiEndpoint := provider["api_endpoint"]
if apiKey == "" || apiEndpoint == "" {
// log.InfoStatusEvent(os.Stdout, "register Azure OpenAI provider used by New()")
ai.RegisterProvider(azopenai.New())
} else {
// log.InfoStatusEvent(os.Stdout, "register Azure OpenAI provider used by NewAzureOpenAIProvider()")
ai.RegisterProvider(azopenai.NewAzureOpenAIProvider(apiKey, apiEndpoint))
}
// register LLM provider
switch name {
case "azopenai":
ai.RegisterProvider(azopenai.NewProvider(
provider["api_key"],
provider["api_endpoint"],
provider["deployment_id"],
provider["api_version"],
))
log.InfoStatusEvent(os.Stdout, "register [%s] AI provider", name)
// TODO: register other providers
}
// register the OpenAI provider
if name == "openai" {
ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
}

// TODO: register other providers
}
}

Expand Down
9 changes: 5 additions & 4 deletions example/10-ai/zipper.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ bridge:

providers:
azopenai:
api_key:
api_endpoint:
deployment_id:
api_key:
api_version:

openai:
api_key:
model:

huggingface:
model:

gemini:
api_key:
41 changes: 24 additions & 17 deletions pkg/bridge/ai/provider/azopenai/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ type RespUsage struct {

// AzureOpenAIProvider is the provider for Azure OpenAI
type AzureOpenAIProvider struct {
APIKey string
APIEndpoint string
APIKey string
APIEndpoint string
DeploymentID string
APIVersion string
}

type connectedFn struct {
Expand All @@ -80,19 +82,25 @@ func init() {
fns = sync.Map{}
}

// NewAzureOpenAIProvider creates a new AzureOpenAIProvider
func NewAzureOpenAIProvider(apiKey string, apiEndpoint string) *AzureOpenAIProvider {
return &AzureOpenAIProvider{
APIKey: apiKey,
APIEndpoint: apiEndpoint,
// NewProvider creates a new AzureOpenAIProvider
func NewProvider(apiKey string, apiEndpoint string, deploymentID string, apiVersion string) *AzureOpenAIProvider {
if apiKey == "" {
apiKey = os.Getenv("AZURE_OPENAI_API_KEY")
}
if apiEndpoint == "" {
apiEndpoint = os.Getenv("AZURE_OPENAI_API_ENDPOINT")
}
if deploymentID == "" {
deploymentID = os.Getenv("AZURE_OPENAI_DEPLOYMENT_ID")
}
if apiVersion == "" {
apiVersion = os.Getenv("AZURE_OPENAI_API_VERSION")
}
}

// New creates a new AzureOpenAIProvider
func New() *AzureOpenAIProvider {
return &AzureOpenAIProvider{
APIKey: os.Getenv("AZURE_OPENAI_API_KEY"),
APIEndpoint: os.Getenv("AZURE_OPENAI_API_ENDPOINT"),
APIKey: apiKey,
APIEndpoint: apiEndpoint,
DeploymentID: deploymentID,
APIVersion: apiVersion,
}
}

Expand Down Expand Up @@ -122,9 +130,6 @@ func (p *AzureOpenAIProvider) GetChatCompletions(userInstruction string) (*ai.In

// prepare tools
toolCalls := make([]ai.ToolCall, 0)
// for _, v := range tools {
// toolCalls = append(toolCalls, v)
// }
fns.Range(func(_, value interface{}) bool {
fn := value.(*connectedFn)
toolCalls = append(toolCalls, fn.tc)
Expand All @@ -140,7 +145,9 @@ func (p *AzureOpenAIProvider) GetChatCompletions(userInstruction string) (*ai.In
return nil, err
}

req, err := http.NewRequest("POST", p.APIEndpoint, bytes.NewBuffer(jsonBody))
url := fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", p.APIEndpoint, p.DeploymentID, p.APIVersion)
ylog.Debug("chat completions request", "url", url)
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/bridge/ai/test/ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestAIToolCalls(t *testing.T) {
}

func startAIServer() error {
ai.RegisterProvider(azopenai.New())
ai.RegisterProvider(azopenai.NewProvider("", "", "gpt35", "2023-12-01-preview"))
aiConfig := &ai.Config{
Server: ai.Server{
Addr: "localhost:6000",
Expand Down

0 comments on commit 882354c

Please sign in to comment.