Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Cloudflare AI Gateway for Azure OpenAI provider #3

Merged
merged 5 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ bridge:

providers:
azopenai:
api_key: <YOUR_AZURE_OPENAI_API_KEY>
api_endpoint: <YOUR_AZURE_OPENAI_ENDPOINT>
api_endpoint: https://<RESOURCE>.openai.azure.com
deployment_id: <DEPLOYMENT_ID>
api_key: <API_KEY>
api_version: <API_VERSION>

openai:
api_key: <OPENAI_API_KEY>
Expand All @@ -67,8 +69,12 @@ bridge:
gemini:
api_key: <GEMINI_API_KEY>

huggingface:
model:
cloudflare_azure:
endpoint: https://gateway.ai.cloudflare.com/v1/<CF_GATEWAY_ID>/<CF_GATEWAY_NAME>
api_key: <AZURE_API_KEY>
resource: <AZURE_OPENAI_RESOURCE>
deployment_id: <AZURE_OPENAI_DEPLOYMENT_ID>
api_version: 2023-12-01-preview
```

Start the server:
Expand Down
12 changes: 11 additions & 1 deletion cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"github.com/yomorun/yomo/pkg/bridge/ai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/azopenai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/cfazure"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/gemini"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
)
Expand Down Expand Up @@ -127,11 +128,20 @@ func registerAIProvider(aiConfig *ai.Config) error {
ai.RegisterProvider(gemini.NewProvider(provider["api_key"]))
case "openai":
ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
case "cloudflare_azure":
ai.RegisterProvider(cfazure.NewProvider(
provider["endpoint"],
provider["api_key"],
provider["resource"],
provider["deployment_id"],
provider["api_version"],
))
default:
log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name)
}
}
log.InfoStatusEvent(os.Stdout, "registered [%s] AI provider", name)

log.InfoStatusEvent(os.Stdout, "registered [%d] AI provider", len(ai.ListProviders()))
return nil
}

Expand Down
8 changes: 8 additions & 0 deletions example/10-ai/zipper.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,11 @@ bridge:

gemini:
api_key:

cloudflare_azure:
endpoint:
api_key:
resource:
deployment_id:
api_version:

22 changes: 14 additions & 8 deletions pkg/bridge/ai/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,32 @@ func ConnMiddleware(next core.ConnHandler) core.ConnHandler {

// Config is the configuration of AI bridge.
// The configuration looks like:
//
// bridge:
//
// ai:
// server:
// host: http://localhost
// port: 8000
// credential: token:<CREDENTIAL>
// provider: azopenai
//
// provider: openai
// providers:
// azopenai:
// api_key:
// api_endpoint:
//
// api_endpoint: https://<RESOURCE>.openai.azure.com
// deployment_id: <DEPLOYMENT_ID>
// api_key: <API_KEY>
// api_version: <API_VERSION>
// openai:
// api_key:
// api_endpoint:
//
// huggingface:
// model:
// gemini:
// api_key:
// cloudflare_azure:
// endpoint: https://gateway.ai.cloudflare.com/v1/<CF_GATEWAY_ID>/<CF_GATEWAY_NAME>
// api_key: <AZURE_API_KEY>
// resource: <AZURE_OPENAI_RESOURCE>
// deployment_id: <AZURE_OPENAI_DEPLOYMENT_ID>
// api_version: <AZURE_OPENAI_API_VERSION>
type Config struct {
Server Server `yaml:"server"` // Server is the configuration of the BasicAPIServer
Providers map[string]Provider `yaml:"providers"` // Providers is the configuration of llm provider
Expand Down
154 changes: 154 additions & 0 deletions pkg/bridge/ai/provider/cfazure/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Package cfazure provides access to the Azure OpenAI service routed through a Cloudflare AI Gateway
package cfazure

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"time"

	// automatically load .env file
	_ "github.com/joho/godotenv/autoload"

	"github.com/yomorun/yomo/ai"
	"github.com/yomorun/yomo/core/metadata"
	"github.com/yomorun/yomo/core/ylog"
	"github.com/yomorun/yomo/pkg/bridge/ai/provider/azopenai"
	"github.com/yomorun/yomo/pkg/bridge/ai/register"
)

// CloudflareAzureProvider is the provider for Azure OpenAI reached
// through a Cloudflare AI Gateway endpoint.
type CloudflareAzureProvider struct {
	APIKey       string // Azure OpenAI API key (sent as the "api-key" header)
	Resource     string // Azure OpenAI resource name (the <RESOURCE> in https://<RESOURCE>.openai.azure.com)
	DeploymentID string // Azure OpenAI deployment id
	APIVersion   string // Azure OpenAI REST API version, e.g. "2023-12-01-preview"
	CfEndpoint   string // Cloudflare AI Gateway base URL, e.g. https://gateway.ai.cloudflare.com/v1/<CF_GATEWAY_ID>/<CF_GATEWAY_NAME>
}

// NewProvider creates a new CloudflareAzureProvider. Every parameter is
// required; if any is empty the process logs the problem and exits, matching
// the fail-fast startup behavior of the other providers.
//
// The API key is deliberately never written to the log; only its presence
// is reported.
func NewProvider(cfEndpoint string, apiKey string, resource string, deploymentID string, apiVersion string) *CloudflareAzureProvider {
	if cfEndpoint == "" || apiKey == "" || resource == "" || deploymentID == "" || apiVersion == "" {
		// report which parameter is missing without leaking the secret itself
		ylog.Error("parameters are required",
			"cfEndpoint", cfEndpoint,
			"apiKeySet", apiKey != "",
			"resource", resource,
			"deploymentID", deploymentID,
			"apiVersion", apiVersion,
		)
		os.Exit(-1)
	}
	ylog.Debug("CloudflareAzureProvider",
		"cfEndpoint", cfEndpoint, // https://gateway.ai.cloudflare.com/v1/<CF_GATEWAY_ID>/<CF_GATEWAY_NAME>
		"apiKeySet", apiKey != "", // never log the key material
		"resource", resource,
		"deploymentID", deploymentID,
		"apiVersion", apiVersion,
	)
	return &CloudflareAzureProvider{
		CfEndpoint:   cfEndpoint,
		APIKey:       apiKey,
		Resource:     resource,
		DeploymentID: deploymentID,
		APIVersion:   apiVersion,
	}
}

// Name reports the registry identifier of this provider.
func (p *CloudflareAzureProvider) Name() string {
	const providerName = "cloudflare_azure"
	return providerName
}

// GetChatCompletions get chat completions for ai service. It forwards the
// user instruction plus the registered tool definitions to Azure OpenAI via
// the Cloudflare AI Gateway, then maps any tool calls in the response back
// to the sfn tags whose registered toolcalls they match.
//
// It returns ai.ErrNoFunctionCall (with any plain-text content populated)
// when the model does not select a function.
func (p *CloudflareAzureProvider) GetChatCompletions(userInstruction string, md metadata.M) (*ai.InvokeResponse, error) {
	tcs, err := register.ListToolCalls(md)
	if err != nil {
		return nil, err
	}

	if len(tcs) == 0 {
		ylog.Error("tools is empty")
		return &ai.InvokeResponse{Content: "no toolcalls"}, ai.ErrNoFunctionCall
	}

	// system + user messages; the system prompt steers the model toward
	// emitting tool calls instead of free-form answers.
	messages := []azopenai.ReqMessage{
		{Role: "system", Content: `You are a very helpful assistant. Your job is to choose the best possible action to solve the user question or task. Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous. If you don't know the answer, stop the conversation by saying "no func call".`},
		{Role: "user", Content: userInstruction},
	}

	// flatten the tag->toolcall map into the request's tools list
	toolCalls := make([]ai.ToolCall, 0, len(tcs))
	for _, tc := range tcs {
		toolCalls = append(toolCalls, tc)
	}

	body := azopenai.ReqBody{Messages: messages, Tools: toolCalls}

	ylog.Debug("request", "tools", len(toolCalls), "messages", messages)

	jsonBody, err := json.Marshal(body)
	if err != nil {
		return nil, err
	}

	// Cloudflare AI Gateway path format:
	// <endpoint>/azure-openai/<resource>/<deployment>/chat/completions?api-version=<version>
	url := fmt.Sprintf("%s/azure-openai/%s/%s/chat/completions?api-version=%s", p.CfEndpoint, p.Resource, p.DeploymentID, p.APIVersion)
	ylog.Debug("chat completions request", "url", url)
	req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("api-key", p.APIKey)

	// bound the request so a stalled gateway cannot hang the caller forever;
	// 90s is generous for a non-streaming chat completion
	client := &http.Client{Timeout: 90 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	ylog.Debug("response", "body", string(respBody))

	if resp.StatusCode >= 400 {
		// include the body: Azure/Cloudflare put the error detail there
		return nil, fmt.Errorf("ai response status code is %d, body: %s", resp.StatusCode, respBody)
	}

	var respBodyStruct azopenai.RespBody
	if err := json.Unmarshal(respBody, &respBodyStruct); err != nil {
		return nil, err
	}

	// guard against an empty choices array (error payloads and content-filtered
	// responses can return it empty) to avoid an index-out-of-range panic
	if len(respBodyStruct.Choices) == 0 {
		return nil, fmt.Errorf("ai response contains no choices")
	}

	choice := respBodyStruct.Choices[0]
	ylog.Debug(">>finish_reason", "reason", choice.FinishReason)

	calls := choice.Message.ToolCalls
	content := choice.Message.Content

	ylog.Debug("--response calls", "calls", calls)

	result := &ai.InvokeResponse{}
	if len(calls) == 0 {
		result.Content = content
		return result, ai.ErrNoFunctionCall
	}

	// functions may be more than one
	for _, call := range calls {
		for tag, tc := range tcs {
			if tc.Equal(&call) {
				// Use toolCalls because tool_id is required in the following llm request
				if result.ToolCalls == nil {
					result.ToolCalls = make(map[uint32][]*ai.ToolCall)
				}

				// copy the loop variable so each stored pointer is distinct
				currentCall := call
				result.ToolCalls[tag] = append(result.ToolCalls[tag], &currentCall)
			}
		}
	}

	// sfn maybe disconnected, so we need to check if there is any function call
	if len(result.ToolCalls) == 0 {
		return nil, ai.ErrNoFunctionCall
	}
	return result, nil
}
Loading