Skip to content

Commit

Permalink
add Cloudflare OpenAI gateway support
Browse files Browse the repository at this point in the history
  • Loading branch information
fanweixiao committed Mar 22, 2024
1 parent 7f08807 commit 689d510
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
7 changes: 7 additions & 0 deletions cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/yomorun/yomo/pkg/bridge/ai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/azopenai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/cfazure"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/cfopenai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/gemini"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
)
Expand Down Expand Up @@ -138,6 +139,12 @@ func registerAIProvider(aiConfig *ai.Config) error {
provider["deployment_id"],
provider["api_version"],
))
case "cloudflare_openai":
ai.RegisterProvider(cfopenai.NewProvider(
provider["endpoint"],
provider["api_key"],
provider["model"],
))
default:
log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name)
}
Expand Down
75 changes: 75 additions & 0 deletions pkg/bridge/ai/provider/cfopenai/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Package cfopenai is used to provide the Azure OpenAI service
package cfopenai

import (
"fmt"
"os"

// automatically load .env file
_ "github.com/joho/godotenv/autoload"

"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core/metadata"
"github.com/yomorun/yomo/core/ylog"
bridgeai "github.com/yomorun/yomo/pkg/bridge/ai"
"github.com/yomorun/yomo/pkg/bridge/ai/internal/openai"
)

// CloudflareOpenaiProvider is the provider for Cloudflare OpenAI Gateway
type CloudflareOpenAIProvider struct {
// CfEndpoint is the your cloudflare endpoint
CfEndpoint string
// APIKey is the API key for OpenAI
APIKey string
// Model is the model for OpenAI
Model string
}

// check if implements ai.Provider
var _ bridgeai.LLMProvider = &CloudflareOpenAIProvider{}

// NewProvider creates a new AzureOpenAIProvider
func NewProvider(cfEndpoint, apiKey, model string) *CloudflareOpenAIProvider {
if apiKey == "" {
apiKey = os.Getenv("OPENAI_API_KEY")
}
if model == "" {
model = os.Getenv("OPENAI_MODEL")
}
if cfEndpoint == "" {
ylog.Error("cfEndpoint is required")
os.Exit(-1)
}
ylog.Debug("new cloudflare openai provider", "api_key", apiKey, "model", model, "cloudflare_endpoint", cfEndpoint)
return &CloudflareOpenAIProvider{
CfEndpoint: cfEndpoint,
APIKey: apiKey,
Model: model,
}
}

// Name returns the name of the provider
func (p *CloudflareOpenAIProvider) Name() string {
return "cloudflare_openai"
}

// GetChatCompletions get chat completions for ai service
func (p *CloudflareOpenAIProvider) GetChatCompletions(userInstruction string, baseSystemMessage string, chainMessage ai.ChainMessage, md metadata.M, withTool bool) (*ai.InvokeResponse, error) {
reqBody := openai.ReqBody{Model: p.Model}

url := fmt.Sprintf("%s/openai/chat/completions", p.CfEndpoint)

res, err := openai.ChatCompletion(
url,
"Authorization",
fmt.Sprintf("Bearer %s", p.APIKey),
reqBody,
baseSystemMessage,
userInstruction,
chainMessage,
md,
withTool,
)

return res, err
}

0 comments on commit 689d510

Please sign in to comment.