From 22906a3743cba9accd83bb142f23fd5ca26a8af2 Mon Sep 17 00:00:00 2001 From: mo Date: Sun, 2 Jun 2024 17:48:04 +0800 Subject: [PATCH 1/3] feat: Support Ali TTS --- controller/relay.go | 2 +- relay/adaptor/aiproxy/adaptor.go | 7 ++ relay/adaptor/ali/adaptor.go | 85 ++++++++++++++++++ relay/adaptor/ali/audio-speech.go | 21 +++++ relay/adaptor/ali/constants.go | 44 +++++++++ relay/adaptor/ali/main.go | 32 +++++++ relay/adaptor/anthropic/adaptor.go | 7 ++ relay/adaptor/aws/adapter.go | 7 ++ relay/adaptor/baidu/adaptor.go | 7 ++ relay/adaptor/cloudflare/adaptor.go | 7 ++ relay/adaptor/cohere/adaptor.go | 7 ++ relay/adaptor/coze/adaptor.go | 7 ++ relay/adaptor/deepl/adaptor.go | 7 ++ relay/adaptor/gemini/adaptor.go | 7 ++ relay/adaptor/interface.go | 1 + relay/adaptor/ollama/adaptor.go | 7 ++ relay/adaptor/openai/adaptor.go | 20 +++++ relay/adaptor/openai/audio.go | 26 ++++++ relay/adaptor/palm/adaptor.go | 7 ++ relay/adaptor/tencent/adaptor.go | 7 ++ relay/adaptor/xunfei/adaptor.go | 7 ++ relay/adaptor/zhipu/adaptor.go | 7 ++ relay/billing/ratio/model.go | 44 +++++++++ relay/controller/audio.go | 135 ++++++++++++++++++++++++++++ relay/controller/helper.go | 10 +++ relay/model/audio.go | 9 ++ 26 files changed, 526 insertions(+), 1 deletion(-) create mode 100644 relay/adaptor/ali/audio-speech.go create mode 100644 relay/adaptor/openai/audio.go create mode 100644 relay/model/audio.go diff --git a/controller/relay.go b/controller/relay.go index 5d8ac69039..4ba8899f49 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -29,7 +29,7 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { case relaymode.ImagesGenerations: err = controller.RelayImageHelper(c, relayMode) case relaymode.AudioSpeech: - fallthrough + err = controller.RelayAudioSpeechHelper(c) case relaymode.AudioTranslation: fallthrough case relaymode.AudioTranscription: diff --git a/relay/adaptor/aiproxy/adaptor.go b/relay/adaptor/aiproxy/adaptor.go index 42d49c0a45..4269e39aca 100644 --- a/relay/adaptor/aiproxy/adaptor.go +++ b/relay/adaptor/aiproxy/adaptor.go @@ -45,6 +45,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/ali/adaptor.go b/relay/adaptor/ali/adaptor.go index 4aa8a11ac3..a9a9559842 100644 --- a/relay/adaptor/ali/adaptor.go +++ b/relay/adaptor/ali/adaptor.go @@ -76,7 +76,19 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return aliRequest, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + aliRequest := ConvertTextToSpeechRequest(*request) + return aliRequest, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + if meta.Mode == relaymode.AudioSpeech { + return a.DoWSSRequest(c, meta, requestBody) + } return adaptor.DoRequestHelper(a, c, meta, requestBody) } @@ -89,6 +101,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met err, usage = EmbeddingHandler(c, resp) case relaymode.ImagesGenerations: err, usage = ImageHandler(c, resp) + case relaymode.AudioSpeech: + err, usage = AudioSpeechHandler(c, resp) default: err, usage = Handler(c, resp) } @@ -103,3 +117,74 @@ func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetChannelName() string { return "ali" } + +func (a *Adaptor) DoWSSRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + baseURL := "wss://dashscope.aliyuncs.com/api-ws/v1/inference" + var usage Usage + // Create an empty http.Response object + response := &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(nil), + } + + conn, _, err := websocket.DefaultDialer.Dial(baseURL, http.Header{"Authorization": {"Bearer " + meta.APIKey}}) + if err != nil { + return response, errors.New("ali_wss_conn_failed") + } + defer conn.Close() + + var requestBodyBytes []byte + requestBodyBytes, err = io.ReadAll(requestBody) + if err != nil { + return response, errors.New("ali_failed_to_read_request_body") + } + + // Convert JSON strings to map[string]interface{} + var requestBodyMap map[string]interface{} + err = json.Unmarshal(requestBodyBytes, &requestBodyMap) + if err != nil { + return response, errors.New("ali_failed_to_parse_request_body") + } + + if err := conn.WriteJSON(requestBodyMap); err != nil { + return response, errors.New("ali_wss_write_msg_failed") + } + + const chunkSize = 1024 + + for { + messageType, audioData, err := conn.ReadMessage() + if err != nil { + if err == io.EOF { + break + } + return response, errors.New("ali_wss_read_msg_failed") + } + + var msg WSSMessage + switch messageType { + case websocket.TextMessage: + err = json.Unmarshal(audioData, &msg) + if msg.Header.Event == "task-finished" { + response.StatusCode = http.StatusOK + usage.TotalTokens = msg.Payload.Usage.Characters + return response, nil + } + case websocket.BinaryMessage: + for i := 0; i < len(audioData); i += chunkSize { + end := i + chunkSize + if end > len(audioData) { + end = len(audioData) + } + chunk := audioData[i:end] + + _, writeErr := c.Writer.Write(chunk) + if writeErr != nil { + return response, errors.New("wss_write_chunk_failed") + } + } + } + } + + return response, nil +} diff --git a/relay/adaptor/ali/audio-speech.go b/relay/adaptor/ali/audio-speech.go new file mode 100644 index 0000000000..7dcd976d29 --- /dev/null +++ b/relay/adaptor/ali/audio-speech.go @@ -0,0 +1,21 @@ +package ali + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" + "net/http" +) + +func AudioSpeechHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, nil +} diff --git a/relay/adaptor/ali/constants.go b/relay/adaptor/ali/constants.go index 3f24ce2e14..c97c345e69 100644 --- a/relay/adaptor/ali/constants.go +++ b/relay/adaptor/ali/constants.go @@ -4,4 +4,48 @@ var ModelList = []string{ "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", "text-embedding-v1", "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", + + "sambert-zhichu-v1", + "sambert-zhiwei-v1", + "sambert-zhixiang-v1", + "sambert-zhide-v1", + "sambert-zhijia-v1", + "sambert-zhinan-v1", + "sambert-zhiqi-v1", + "sambert-zhiqian-v1", + "sambert-zhiru-v1", + "sambert-zhimiao-emo-v1", + "sambert-zhida-v1", + "sambert-zhifei-v1", + "sambert-zhigui-v1", + "sambert-zhihao-v1", + "sambert-zhijing-v1", + "sambert-zhilun-v1", + "sambert-zhimao-v1", + "sambert-zhiming-v1", + "sambert-zhimo-v1", + "sambert-zhina-v1", + "sambert-zhishu-v1", + "sambert-zhishuo-v1", + "sambert-zhistella-v1", + "sambert-zhiting-v1", + "sambert-zhixiao-v1", + "sambert-zhiya-v1", + "sambert-zhiye-v1", + "sambert-zhiying-v1", + "sambert-zhiyuan-v1", + "sambert-zhiyue-v1", + "sambert-camila-v1", + "sambert-perla-v1", + "sambert-indah-v1", + "sambert-clara-v1", + "sambert-hanna-v1", + "sambert-beth-v1", + "sambert-betty-v1", + "sambert-cally-v1", + "sambert-cindy-v1", + "sambert-eva-v1", + "sambert-donna-v1", + "sambert-brian-v1", + "sambert-waan-v1", } diff --git a/relay/adaptor/ali/main.go b/relay/adaptor/ali/main.go index 0462c26b06..a054d4ebda 100644 --- a/relay/adaptor/ali/main.go +++ b/relay/adaptor/ali/main.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "github.com/gin-gonic/gin" + "github.com/google/uuid" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" @@ -77,6 +78,37 @@ func ConvertImageRequest(request model.ImageRequest) *ImageRequest { return &imageRequest } +func ConvertTextToSpeechRequest(request model.TextToSpeechRequest) *WSSMessage { + var ttsRequest WSSMessage + ttsRequest.Header.Action = "run-task" + ttsRequest.Header.Streaming = "out" + ttsRequest.Header.TaskID = uuid.New().String() + ttsRequest.Payload.Function = "SpeechSynthesizer" + ttsRequest.Payload.Input.Text = request.Input + ttsRequest.Payload.Model = request.Model + ttsRequest.Payload.Parameters.Format = "wav" + //ttsRequest.Payload.Parameters.SampleRate = 48000 + ttsRequest.Payload.Parameters.Rate = 1.0 + ttsRequest.Payload.Task = "tts" + ttsRequest.Payload.TaskGroup = "audio" + + format := map[string]bool{ + "pcm": true, + "wav": true, + "mp3": true, + } + + if _, ok := format[request.ResponseFormat]; ok { + ttsRequest.Payload.Parameters.Format = request.ResponseFormat + } + + if 0.5 <= request.Speed && request.Speed <= 2 { + ttsRequest.Payload.Parameters.Rate = request.Speed + } + + return &ttsRequest +} + func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var aliResponse EmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&aliResponse) diff --git a/relay/adaptor/anthropic/adaptor.go b/relay/adaptor/anthropic/adaptor.go index b1136e8462..c59808a38a 100644 --- a/relay/adaptor/anthropic/adaptor.go +++ b/relay/adaptor/anthropic/adaptor.go @@ -48,6 +48,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/aws/adapter.go b/relay/adaptor/aws/adapter.go index 7245d3d9fe..ecc2e02f5e 100644 --- a/relay/adaptor/aws/adapter.go +++ b/relay/adaptor/aws/adapter.go @@ -57,6 +57,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return nil, nil } diff --git a/relay/adaptor/baidu/adaptor.go b/relay/adaptor/baidu/adaptor.go index 15306b95a4..19af64dd1a 100644 --- a/relay/adaptor/baidu/adaptor.go +++ b/relay/adaptor/baidu/adaptor.go @@ -116,6 +116,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/cloudflare/adaptor.go b/relay/adaptor/cloudflare/adaptor.go index 6ff6b0d344..d3ae2b29c9 100644 --- a/relay/adaptor/cloudflare/adaptor.go +++ b/relay/adaptor/cloudflare/adaptor.go @@ -44,6 +44,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/cohere/adaptor.go b/relay/adaptor/cohere/adaptor.go index 6fdb1b0441..523bf76174 100644 --- a/relay/adaptor/cohere/adaptor.go +++ b/relay/adaptor/cohere/adaptor.go @@ -42,6 +42,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/coze/adaptor.go b/relay/adaptor/coze/adaptor.go index 44f560e8ef..95f84f0159 100644 --- a/relay/adaptor/coze/adaptor.go +++ b/relay/adaptor/coze/adaptor.go @@ -45,6 +45,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/deepl/adaptor.go b/relay/adaptor/deepl/adaptor.go index d018a0966f..ff4efec487 100644 --- a/relay/adaptor/deepl/adaptor.go +++ b/relay/adaptor/deepl/adaptor.go @@ -46,6 +46,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go index 12f48c715a..aff89287d7 100644 --- a/relay/adaptor/gemini/adaptor.go +++ b/relay/adaptor/gemini/adaptor.go @@ -66,6 +66,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channelhelper.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/interface.go b/relay/adaptor/interface.go index 01b2e2cbad..f0af3f1418 100644 --- a/relay/adaptor/interface.go +++ b/relay/adaptor/interface.go @@ -14,6 +14,7 @@ type Adaptor interface { SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) ConvertImageRequest(request *model.ImageRequest) (any, error) + ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) GetModelList() []string diff --git a/relay/adaptor/ollama/adaptor.go b/relay/adaptor/ollama/adaptor.go index 66702c5dde..4c726d35cb 100644 --- a/relay/adaptor/ollama/adaptor.go +++ b/relay/adaptor/ollama/adaptor.go @@ -55,6 +55,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 2e2e410045..5b2d14bbb8 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -31,6 +31,14 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion) return fullRequestURL, nil + } else if meta.Mode == relaymode.AudioTranscription { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion) + return fullRequestURL, nil + } else if meta.Mode == relaymode.AudioSpeech { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api + fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion) + return fullRequestURL, nil } // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api @@ -54,6 +62,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me adaptor.SetupCommonRequestHeader(c, req, meta) if meta.ChannelType == channeltype.Azure { req.Header.Set("api-key", meta.APIKey) + if meta.Mode == relaymode.AudioTranscription || meta.Mode == relaymode.AudioSpeech { + req.ContentLength = c.Request.ContentLength + } return nil } req.Header.Set("Authorization", "Bearer "+meta.APIKey) @@ -78,6 +89,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } @@ -97,6 +115,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met switch meta.Mode { case relaymode.ImagesGenerations: err, _ = ImageHandler(c, resp) + case relaymode.AudioSpeech: + err, _ = TextToSpeechHandler(c, resp) default: err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) } diff --git a/relay/adaptor/openai/audio.go b/relay/adaptor/openai/audio.go new file mode 100644 index 0000000000..674f3e8da8 --- /dev/null +++ b/relay/adaptor/openai/audio.go @@ -0,0 +1,26 @@ +package openai + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +func TextToSpeechHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var err error + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, nil +} diff --git a/relay/adaptor/palm/adaptor.go b/relay/adaptor/palm/adaptor.go index 98aa3e187f..0ecc6744a1 100644 --- a/relay/adaptor/palm/adaptor.go +++ b/relay/adaptor/palm/adaptor.go @@ -43,6 +43,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/tencent/adaptor.go b/relay/adaptor/tencent/adaptor.go index a97476d6cd..70cb059c5a 100644 --- a/relay/adaptor/tencent/adaptor.go +++ b/relay/adaptor/tencent/adaptor.go @@ -59,6 +59,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/xunfei/adaptor.go b/relay/adaptor/xunfei/adaptor.go index 3af97831d3..174c065909 100644 --- a/relay/adaptor/xunfei/adaptor.go +++ b/relay/adaptor/xunfei/adaptor.go @@ -54,6 +54,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} diff --git a/relay/adaptor/zhipu/adaptor.go b/relay/adaptor/zhipu/adaptor.go index 78b01fb3f7..5e4317f5e6 100644 --- a/relay/adaptor/zhipu/adaptor.go +++ b/relay/adaptor/zhipu/adaptor.go @@ -92,6 +92,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return newRequest, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 3b289499bf..ae37fce81d 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -119,6 +119,50 @@ var ModelRatio = map[string]float64{ "ali-stable-diffusion-xl": 8, "ali-stable-diffusion-v1.5": 8, "wanx-v1": 8, + // https://help.aliyun.com/zh/dashscope/developer-reference/sambert-speech-synthesis-metered-billing?spm=a2c4g.11186623.0.0.12a52e5cvlyKYj + "sambert-zhichu-v1": 0.1 * RMB, // 1 RMB / 10K characters -> 0.1 RMB / 1K characters -> 0.1 RMB / 1K tokens + "sambert-zhiwei-v1": 0.1 * RMB, + "sambert-zhixiang-v1": 0.1 * RMB, + "sambert-zhide-v1": 0.1 * RMB, + "sambert-zhijia-v1": 0.1 * RMB, + "sambert-zhinan-v1": 0.1 * RMB, + "sambert-zhiqi-v1": 0.1 * RMB, + "sambert-zhiqian-v1": 0.1 * RMB, + "sambert-zhiru-v1": 0.1 * RMB, + "sambert-zhimiao-emo-v1": 0.1 * RMB, + "sambert-zhida-v1": 0.1 * RMB, + "sambert-zhifei-v1": 0.1 * RMB, + "sambert-zhigui-v1": 0.1 * RMB, + "sambert-zhihao-v1": 0.1 * RMB, + "sambert-zhijing-v1": 0.1 * RMB, + "sambert-zhilun-v1": 0.1 * RMB, + "sambert-zhimao-v1": 0.1 * RMB, + "sambert-zhiming-v1": 0.1 * RMB, + "sambert-zhimo-v1": 0.1 * RMB, + "sambert-zhina-v1": 0.1 * RMB, + "sambert-zhishu-v1": 0.1 * RMB, + "sambert-zhishuo-v1": 0.1 * RMB, + "sambert-zhistella-v1": 0.1 * RMB, + "sambert-zhiting-v1": 0.1 * RMB, + "sambert-zhixiao-v1": 0.1 * RMB, + "sambert-zhiya-v1": 0.1 * RMB, + "sambert-zhiye-v1": 0.1 * RMB, + "sambert-zhiying-v1": 0.1 * RMB, + "sambert-zhiyuan-v1": 0.1 * RMB, + "sambert-zhiyue-v1": 0.1 * RMB, + "sambert-camila-v1": 0.1 * RMB, + "sambert-perla-v1": 0.1 * RMB, + "sambert-indah-v1": 0.1 * RMB, + "sambert-clara-v1": 0.1 * RMB, + "sambert-hanna-v1": 0.1 * RMB, + "sambert-beth-v1": 0.1 * RMB, + "sambert-betty-v1": 0.1 * RMB, + "sambert-cally-v1": 0.1 * RMB, + "sambert-cindy-v1": 0.1 * RMB, + "sambert-eva-v1": 0.1 * RMB, + "sambert-donna-v1": 0.1 * RMB, + "sambert-brian-v1": 0.1 * RMB, + "sambert-waan-v1": 0.1 * RMB, "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 8f9708d080..a872cd8811 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -14,6 +14,7 @@ import ( "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/billing" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" @@ -26,6 +27,140 @@ import ( "strings" ) +func RelayAudioSpeechHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode { + ctx := c.Request.Context() + meta := meta.GetByContext(c) + audioModel := "whisper-1" + + tokenId := c.GetInt(ctxkey.TokenId) + channelId := c.GetInt(ctxkey.ChannelId) + userId := c.GetInt(ctxkey.Id) + group := c.GetString(ctxkey.Group) + tokenName := c.GetString(ctxkey.TokenName) + + ttsRequest, err := getTextToSpeechRequest(c) + if err != nil { + logger.Errorf(ctx, "getTextToSpeechRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "invalid_tts_request", http.StatusBadRequest) + } + + audioModel = ttsRequest.Model + // Check if text is too long 4096 + if len(ttsRequest.Input) > 4096 { + return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) + } + + adaptor := relay.GetAdaptor(meta.APIType) + if adaptor == nil { + return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) + } + + modelRatio := billingratio.GetModelRatio(audioModel) + groupRatio := billingratio.GetGroupRatio(group) + ratio := modelRatio * groupRatio + + preConsumedQuota := int64(float64(len(ttsRequest.Input)) * ratio) + quota := preConsumedQuota + + userQuota, err := model.CacheGetUserQuota(ctx, userId) + if err != nil { + return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + } + + // Check if user quota is enough + if userQuota-preConsumedQuota < 0 { + return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if err != nil { + return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } + } + succeed := false + defer func() { + if succeed { + return + } + if preConsumedQuota > 0 { + // we need to roll back the pre-consumed quota + defer func(ctx context.Context) { + go func() { + // negative means add quota back for token & user + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) + } + }() + }(c.Request.Context()) + } + }() + + // map model name + modelMapping := c.GetString(ctxkey.ModelMapping) + if modelMapping != "" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[audioModel] != "" { + audioModel = modelMap[audioModel] + } + } + + var requestBody io.Reader + + switch meta.ChannelType { + case channeltype.Ali: + finalRequest, err := adaptor.ConvertTextToSpeechRequest(ttsRequest) + if err != nil { + return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) + } + jsonStr, err := json.Marshal(finalRequest) + if err != nil { + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + default: + requestBody = c.Request.Body + } + + // do request + resp, err := adaptor.DoRequest(c, meta, requestBody) + if err != nil { + logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + if resp.StatusCode != http.StatusOK { + return RelayErrorHandler(resp) + } + succeed = true + quotaDelta := quota - preConsumedQuota + defer func(ctx context.Context) { + go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + }(c.Request.Context()) + + // do response + _, respErr := adaptor.DoResponse(c, resp, meta) + if respErr != nil { + logger.Errorf(ctx, "respErr is not nil: %+v", respErr) + return respErr + } + + return nil +} + func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() meta := meta.GetByContext(c) diff --git a/relay/controller/helper.go b/relay/controller/helper.go index dccff486cb..030e21a5e1 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -58,6 +58,16 @@ func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, e return imageRequest, nil } +func getTextToSpeechRequest(c *gin.Context) (*relaymodel.TextToSpeechRequest, error) { + ttsRequest := &relaymodel.TextToSpeechRequest{} + err := common.UnmarshalBodyReusable(c, ttsRequest) + if err != nil { + return nil, err + } + + return ttsRequest, nil +} + func isValidImageSize(model string, size string) bool { if model == "cogview-3" { return true diff --git a/relay/model/audio.go b/relay/model/audio.go new file mode 100644 index 0000000000..7542036add --- /dev/null +++ b/relay/model/audio.go @@ -0,0 +1,9 @@ +package model + +type TextToSpeechRequest struct { + Model string `json:"model" binding:"required"` + Input string `json:"input" binding:"required"` + Voice string `json:"voice" binding:"required"` + Speed float64 `json:"speed"` + ResponseFormat string `json:"response_format"` +} From 82ff2857bc229a999a8c18c6789cd79130aa2352 Mon Sep 17 00:00:00 2001 From: mo Date: Sun, 2 Jun 2024 17:48:37 +0800 Subject: [PATCH 2/3] feat: Support Ali TTS --- relay/adaptor/ali/adaptor.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/relay/adaptor/ali/adaptor.go b/relay/adaptor/ali/adaptor.go index a9a9559842..787855b692 100644 --- a/relay/adaptor/ali/adaptor.go +++ b/relay/adaptor/ali/adaptor.go @@ -1,9 +1,11 @@ package ali import ( + "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" From 7ed4bc29a02bd7f6ce2212007a22180e8fe1afed Mon Sep 17 00:00:00 2001 From: mo Date: Sun, 2 Jun 2024 17:53:25 +0800 Subject: [PATCH 3/3] feat: Support Ali TTS --- relay/controller/audio.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/controller/audio.go b/relay/controller/audio.go index a872cd8811..81f63e8df0 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -30,7 +30,7 @@ import ( func RelayAudioSpeechHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() meta := meta.GetByContext(c) - audioModel := "whisper-1" + audioModel := "tts-1" tokenId := c.GetInt(ctxkey.TokenId) channelId := c.GetInt(ctxkey.ChannelId)