1
- import json
2
1
import logging
3
2
import shlex
4
3
import subprocess
12
11
import tyro
13
12
import uvicorn
14
13
from attr import dataclass
15
- from fastapi import Request
14
+ from fastapi import File , Form , HTTPException , UploadFile , status
16
15
from fastapi .responses import Response
17
16
18
17
from fam .llm .fast_inference import TTS
@@ -50,55 +49,55 @@ class _GlobalState:
50
49
GlobalState = _GlobalState ()
51
50
52
51
53
- @dataclass (frozen = True )
54
- class TTSRequest :
55
- text : str
56
- speaker_ref_path : Optional [str ] = None
57
- guidance : float = 3.0
58
- top_p : float = 0.95
59
- top_k : Optional [int ] = None
60
-
61
-
62
52
@app .get ("/health" )
63
53
async def health_check ():
64
54
return {"status" : "ok" }
65
55
66
56
67
57
@app .post ("/tts" , response_class = Response )
68
- async def text_to_speech (req : Request ):
69
- audiodata = await req .body ()
70
- payload = None
58
+ async def text_to_speech (
59
+ text : str = Form (...),
60
+ speaker_ref_path : Optional [str ] = Form (None ),
61
+ guidance : float = Form (3.0 ),
62
+ top_p : float = Form (0.95 ),
63
+ audiodata : Optional [UploadFile ] = File (None ),
64
+ ):
65
+ # Ensure at least one of speaker_ref_path or audiodata is provided
66
+ if not audiodata and not speaker_ref_path :
67
+ raise HTTPException (
68
+ status_code = status .HTTP_400_BAD_REQUEST ,
69
+ detail = "Either an audio file or a speaker reference path must be provided." ,
70
+ )
71
+
71
72
wav_out_path = None
72
73
73
74
try :
74
- headers = req .headers
75
- payload = headers ["X-Payload" ]
76
- payload = json .loads (payload )
77
- tts_req = TTSRequest (** payload )
78
75
with tempfile .NamedTemporaryFile (suffix = ".wav" ) as wav_tmp :
79
- if tts_req . speaker_ref_path is None :
76
+ if speaker_ref_path is None :
80
77
wav_path = _convert_audiodata_to_wav_path (audiodata , wav_tmp )
81
78
check_audio_file (wav_path )
82
79
else :
83
80
# TODO: fix
84
- wav_path = tts_req . speaker_ref_path
81
+ wav_path = speaker_ref_path
85
82
86
83
if wav_path is None :
87
84
warnings .warn ("Running without speaker reference" )
88
- assert tts_req . guidance is None
85
+ assert guidance is None
89
86
90
87
wav_out_path = GlobalState .tts .synthesise (
91
- text = tts_req . text ,
88
+ text = text ,
92
89
spk_ref_path = wav_path ,
93
- top_p = tts_req . top_p ,
94
- guidance_scale = tts_req . guidance ,
90
+ top_p = top_p ,
91
+ guidance_scale = guidance ,
95
92
)
96
93
97
94
with open (wav_out_path , "rb" ) as f :
98
95
return Response (content = f .read (), media_type = "audio/wav" )
99
96
except Exception as e :
100
97
# traceback_str = "".join(traceback.format_tb(e.__traceback__))
101
- logger .exception (f"Error processing request { payload } " )
98
+ logger .exception (
99
+ f"Error processing request. text: { text } , speaker_ref_path: { speaker_ref_path } , guidance: { guidance } , top_p: { top_p } "
100
+ )
102
101
return Response (
103
102
content = "Something went wrong. Please try again in a few mins or contact us on Discord" ,
104
103
status_code = 500 ,
@@ -108,9 +107,9 @@ async def text_to_speech(req: Request):
108
107
Path (wav_out_path ).unlink (missing_ok = True )
109
108
110
109
111
- def _convert_audiodata_to_wav_path (audiodata , wav_tmp ):
110
+ def _convert_audiodata_to_wav_path (audiodata : UploadFile , wav_tmp ):
112
111
with tempfile .NamedTemporaryFile () as unknown_format_tmp :
113
- if unknown_format_tmp .write (audiodata ) == 0 :
112
+ if unknown_format_tmp .write (audiodata . read () ) == 0 :
114
113
return None
115
114
unknown_format_tmp .flush ()
116
115
@@ -129,7 +128,11 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
129
128
logging .root .setLevel (logging .INFO )
130
129
131
130
GlobalState .config = tyro .cli (ServingConfig )
132
- GlobalState .tts = TTS (seed = GlobalState .config .seed , quantisation_mode = GlobalState .config .quantisation_mode )
131
+ GlobalState .tts = TTS (
132
+ seed = GlobalState .config .seed ,
133
+ quantisation_mode = GlobalState .config .quantisation_mode ,
134
+ telemetry_origin = "api_server" ,
135
+ )
133
136
134
137
app .add_middleware (
135
138
fastapi .middleware .cors .CORSMiddleware ,
0 commit comments