Skip to content

Commit

Permalink
azure ai support (#997)
Browse files Browse the repository at this point in the history
  • Loading branch information
sawka authored Oct 9, 2024
1 parent b81ab63 commit f33028a
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 68 deletions.
9 changes: 6 additions & 3 deletions frontend/app/view/waveai/waveai.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,15 @@ export class WaveAiModel implements ViewModel {
};
addMessage(newMessage);
// send message to backend and get response
const settings = globalStore.get(atoms.settingsAtom);
const settings = globalStore.get(atoms.settingsAtom) ?? {};
const opts: OpenAIOptsType = {
model: settings["ai:model"],
apitype: settings["ai:apitype"],
orgid: settings["ai:orgid"],
apitoken: settings["ai:apitoken"],
apiversion: settings["ai:apiversion"],
maxtokens: settings["ai:maxtokens"],
timeout: settings["ai:timeoutms"] / 1000,
timeoutms: settings["ai:timeoutms"] ?? 60000,
baseurl: settings["ai:baseurl"],
};
const newPrompt: OpenAIPromptMessageType = {
Expand All @@ -222,7 +225,7 @@ export class WaveAiModel implements ViewModel {
opts: opts,
prompt: [...history, newPrompt],
};
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: 60000 });
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
let fullMsg = "";
for await (const msg of aiGen) {
fullMsg += msg.text ?? "";
Expand Down
8 changes: 7 additions & 1 deletion frontend/types/gotypes.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,14 @@ declare global {
// wshrpc.OpenAIOptsType
type OpenAIOptsType = {
model: string;
apitype?: string;
apitoken: string;
orgid?: string;
apiversion?: string;
baseurl?: string;
maxtokens?: number;
maxchoices?: number;
timeout?: number;
timeoutms?: number;
};

// wshrpc.OpenAIPacketType
Expand Down Expand Up @@ -413,10 +416,13 @@ declare global {
// wconfig.SettingsType
type SettingsType = {
"ai:*"?: boolean;
"ai:apitype"?: string;
"ai:baseurl"?: string;
"ai:apitoken"?: string;
"ai:name"?: string;
"ai:model"?: string;
"ai:orgid"?: string;
"ai:apiversion"?: string;
"ai:maxtokens"?: number;
"ai:timeoutms"?: number;
"term:*"?: boolean;
Expand Down
115 changes: 67 additions & 48 deletions pkg/waveai/waveai.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ import (
"io"
"log"
"os"
"regexp"
"strings"
"time"

"github.com/sashabaranov/go-openai"
openaiapi "github.com/sashabaranov/go-openai"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wcloud"
Expand All @@ -23,6 +26,7 @@ import (
const OpenAIPacketStr = "openai"
const OpenAICloudReqStr = "openai-cloudreq"
const PacketEOFStr = "EOF"
const DefaultAzureAPIVersion = "2023-05-15"

type OpenAICmdInfoPacketOutputType struct {
Model string `json:"model,omitempty"`
Expand Down Expand Up @@ -52,43 +56,38 @@ type OpenAICloudReqPacketType struct {
MaxChoices int `json:"maxchoices,omitempty"`
}

type OpenAIOptsType struct {
Model string `json:"model"`
APIToken string `json:"apitoken"`
BaseURL string `json:"baseurl,omitempty"`
MaxTokens int `json:"maxtokens,omitempty"`
MaxChoices int `json:"maxchoices,omitempty"`
Timeout int `json:"timeout,omitempty"`
BlockId string `json:"blockid"`
}

func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
return &OpenAICloudReqPacketType{
Type: OpenAICloudReqStr,
}
}

func GetWSEndpoint() string {
return PCloudWSEndpoint
if !wavebase.IsDevMode() {
return PCloudWSEndpoint
return WCloudWSEndpoint
} else {
endpoint := os.Getenv(PCloudWSEndpointVarName)
endpoint := os.Getenv(WCloudWSEndpointVarName)
if endpoint == "" {
panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid")
panic("Invalid WCloud websocket dev endpoint, WCLOUD_WS_ENDPOINT not set or invalid")
}
return endpoint
}
}

const DefaultMaxTokens = 1000
const DefaultMaxTokens = 2048
const DefaultModel = "gpt-4o-mini"
const DefaultStreamChanSize = 10
const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT"
const WCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
const WCloudWSEndpointVarName = "WCLOUD_WS_ENDPOINT"

const CloudWebsocketConnectTimeout = 1 * time.Minute

func IsCloudAIRequest(opts *wshrpc.OpenAIOptsType) bool {
if opts == nil {
return true
}
return opts.BaseURL == "" && opts.APIToken == ""
}

func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType {
if resp.Usage.TotalTokens == 0 {
return nil
Expand All @@ -113,6 +112,15 @@ func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
}

func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
if IsCloudAIRequest(request.Opts) {
log.Print("sending ai chat message to default waveterm cloud endpoint\n")
return RunCloudCompletionStream(ctx, request)
}
log.Printf("sending ai chat message to user-configured endpoint %s\n", request.Opts.BaseURL)
return RunLocalCompletionStream(ctx, request)
}

func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
wsEndpoint := wcloud.GetWSEndpoint()
Expand Down Expand Up @@ -187,6 +195,36 @@ func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
return rtn
}

// copied from go-openai/config.go
func defaultAzureMapperFn(model string) string {
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
}

func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openai.ClientConfig) error {
ourApiType := strings.ToLower(opts.APIType)
if ourApiType == "" || ourApiType == strings.ToLower(string(openai.APITypeOpenAI)) {
clientConfig.APIType = openai.APITypeOpenAI
return nil
} else if ourApiType == strings.ToLower(string(openai.APITypeAzure)) {
clientConfig.APIType = openai.APITypeAzure
clientConfig.APIVersion = DefaultAzureAPIVersion
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
return nil
} else if ourApiType == strings.ToLower(string(openai.APITypeAzureAD)) {
clientConfig.APIType = openai.APITypeAzureAD
clientConfig.APIVersion = DefaultAzureAPIVersion
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
return nil
} else if ourApiType == strings.ToLower(string(openai.APITypeCloudflareAzure)) {
clientConfig.APIType = openai.APITypeCloudflareAzure
clientConfig.APIVersion = DefaultAzureAPIVersion
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
return nil
} else {
return fmt.Errorf("invalid api type %q", opts.APIType)
}
}

func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
go func() {
Expand All @@ -207,6 +245,17 @@ func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
if request.Opts.BaseURL != "" {
clientConfig.BaseURL = request.Opts.BaseURL
}
err := setApiType(request.Opts, &clientConfig)
if err != nil {
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
return
}
if request.Opts.OrgID != "" {
clientConfig.OrgID = request.Opts.OrgID
}
if request.Opts.APIVersion != "" {
clientConfig.APIVersion = request.Opts.APIVersion
}
client := openaiapi.NewClientWithConfig(clientConfig)
req := openaiapi.ChatCompletionRequest{
Model: request.Opts.Model,
Expand Down Expand Up @@ -251,33 +300,3 @@ func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
}()
return rtn
}

func marshalResponse(resp openaiapi.ChatCompletionResponse) []*wshrpc.OpenAIPacketType {
var rtn []*wshrpc.OpenAIPacketType
headerPk := MakeOpenAIPacket()
headerPk.Model = resp.Model
headerPk.Created = resp.Created
headerPk.Usage = convertUsage(resp)
rtn = append(rtn, headerPk)
for _, choice := range resp.Choices {
choicePk := MakeOpenAIPacket()
choicePk.Index = choice.Index
choicePk.Text = choice.Message.Content
choicePk.FinishReason = string(choice.FinishReason)
rtn = append(rtn, choicePk)
}
return rtn
}

func CreateErrorPacket(errStr string) *wshrpc.OpenAIPacketType {
errPk := MakeOpenAIPacket()
errPk.FinishReason = "error"
errPk.Error = errStr
return errPk
}

func CreateTextPacket(text string) *wshrpc.OpenAIPacketType {
pk := MakeOpenAIPacket()
pk.Text = text
return pk
}
4 changes: 2 additions & 2 deletions pkg/wconfig/defaultconfig/settings.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"ai:model": "gpt-4o-mini",
"ai:maxtokens": 1000,
"ai:timeoutms": 10000,
"ai:maxtokens": 2048,
"ai:timeoutms": 60000,
"autoupdate:enabled": true,
"autoupdate:installonquit": true,
"autoupdate:intervalms": 3600000,
Expand Down
3 changes: 3 additions & 0 deletions pkg/wconfig/metaconsts.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ package wconfig

const (
ConfigKey_AiClear = "ai:*"
ConfigKey_AiApiType = "ai:apitype"
ConfigKey_AiBaseURL = "ai:baseurl"
ConfigKey_AiApiToken = "ai:apitoken"
ConfigKey_AiName = "ai:name"
ConfigKey_AiModel = "ai:model"
ConfigKey_AiOrgID = "ai:orgid"
ConfigKey_AIApiVersion = "ai:apiversion"
ConfigKey_AiMaxTokens = "ai:maxtokens"
ConfigKey_AiTimeoutMs = "ai:timeoutms"

Expand Down
17 changes: 10 additions & 7 deletions pkg/wconfig/settingsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,16 @@ func (m MetaSettingsType) MarshalJSON() ([]byte, error) {
}

type SettingsType struct {
AiClear bool `json:"ai:*,omitempty"`
AiBaseURL string `json:"ai:baseurl,omitempty"`
AiApiToken string `json:"ai:apitoken,omitempty"`
AiName string `json:"ai:name,omitempty"`
AiModel string `json:"ai:model,omitempty"`
AiMaxTokens float64 `json:"ai:maxtokens,omitempty"`
AiTimeoutMs float64 `json:"ai:timeoutms,omitempty"`
AiClear bool `json:"ai:*,omitempty"`
AiApiType string `json:"ai:apitype,omitempty"`
AiBaseURL string `json:"ai:baseurl,omitempty"`
AiApiToken string `json:"ai:apitoken,omitempty"`
AiName string `json:"ai:name,omitempty"`
AiModel string `json:"ai:model,omitempty"`
AiOrgID string `json:"ai:orgid,omitempty"`
AIApiVersion string `json:"ai:apiversion,omitempty"`
AiMaxTokens float64 `json:"ai:maxtokens,omitempty"`
AiTimeoutMs float64 `json:"ai:timeoutms,omitempty"`

TermClear bool `json:"term:*,omitempty"`
TermFontSize float64 `json:"term:fontsize,omitempty"`
Expand Down
5 changes: 4 additions & 1 deletion pkg/wshrpc/wshrpctypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,14 @@ type OpenAIPromptMessageType struct {

type OpenAIOptsType struct {
Model string `json:"model"`
APIType string `json:"apitype,omitempty"`
APIToken string `json:"apitoken"`
OrgID string `json:"orgid,omitempty"`
APIVersion string `json:"apiversion,omitempty"`
BaseURL string `json:"baseurl,omitempty"`
MaxTokens int `json:"maxtokens,omitempty"`
MaxChoices int `json:"maxchoices,omitempty"`
Timeout int `json:"timeout,omitempty"`
TimeoutMs int `json:"timeoutms,omitempty"`
}

type OpenAIPacketType struct {
Expand Down
7 changes: 1 addition & 6 deletions pkg/wshrpc/wshserver/wshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,7 @@ func (ws *WshServer) StreamTestCommand(ctx context.Context) chan wshrpc.RespOrEr
}

func (ws *WshServer) StreamWaveAiCommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
log.Print("sending ai chat message to waveterm default endpoint with openai\n")
return waveai.RunCloudCompletionStream(ctx, request)
}
log.Printf("sending ai chat message to user-configured endpoint %s\n", request.Opts.BaseURL)
return waveai.RunLocalCompletionStream(ctx, request)
return waveai.RunAICommand(ctx, request)
}

func MakePlotData(ctx context.Context, blockId string) error {
Expand Down

0 comments on commit f33028a

Please sign in to comment.