
Commit dab7454

Authored by sidroopdaska (Siddharth Sharma)
feat: add telemetry origin (#141)
* feat: anonymised telemetry to track usage patterns
* add: PR suggestions
* feat: add telemetry origin field
* feat: fix POST definition

Co-authored-by: Siddharth Sharma <[email protected]>
Co-authored-by: siddharth sharma <[email protected]>
1 parent eebdcc6 · commit dab7454

File tree: 4 files changed (+41, -30 lines)


README.md

Lines changed: 4 additions & 0 deletions
@@ -30,6 +30,7 @@ docker-compose up -d ui && docker-compose ps && docker-compose logs -f
 
 Server
 ```bash
+# navigate to <URL>/docs for API definitions
 docker-compose up -d server && docker-compose ps && docker-compose logs -f
 ```
 
@@ -102,7 +103,10 @@ tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-s
 ```bash
 # You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio.
 # Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16.
+
+# navigate to <URL>/docs for API definitions
 poetry run python serving.py
+
 poetry run python app.py
 ```

app.py

Lines changed: 2 additions & 2 deletions
@@ -13,14 +13,14 @@
 from fam.llm.utils import check_audio_file
 
 #### setup model
-TTS_MODEL = tyro.cli(TTS)
+TTS_MODEL = tyro.cli(TTS, args=["--telemetry_origin", "webapp"])
 
 #### setup interface
 RADIO_CHOICES = ["Preset voices", "Upload target voice (atleast 30s)"]
 MAX_CHARS = 220
 PRESET_VOICES = {
     # female
-    "Bria": "https://cdn.themetavoice.xyz/speakers%2Fbria.mp3",
+    "Bria": "https://cdn.themetavoice.xyz/speakers/bria.mp3",
     # male
     "Alex": "https://cdn.themetavoice.xyz/speakers/alex.mp3",
     "Jacob": "https://cdn.themetavoice.xyz/speakers/jacob.wav",

fam/llm/fast_inference.py

Lines changed: 4 additions & 0 deletions
@@ -46,6 +46,7 @@ def __init__(
         output_dir: str = "outputs",
         quantisation_mode: Optional[Literal["int4", "int8"]] = None,
         first_stage_path: Optional[str] = None,
+        telemetry_origin: Optional[str] = None,
     ):
         """
         Initialise the TTS model.
@@ -60,6 +61,7 @@ def __init__(
                 - int4 for int4 weight-only quantisation,
                 - int8 for int8 weight-only quantisation.
             first_stage_path: path to first-stage LLM checkpoint. If provided, this will override the one grabbed from Hugging Face via `model_name`.
+            telemetry_origin: A string identifier that specifies the origin of the telemetry data sent to PostHog.
         """
 
         # NOTE: this needs to come first so that we don't change global state when we want to use
@@ -104,6 +106,7 @@ def __init__(
         self._seed = seed
         self._quantisation_mode = quantisation_mode
         self._model_name = model_name
+        self._telemetry_origin = telemetry_origin
 
     def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
         """
@@ -183,6 +186,7 @@ def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.
                     "seed": self._seed,
                     "first_stage_ckpt": self._first_stage_ckpt,
                     "gpu": torch.cuda.get_device_name(0),
+                    "telemetry_origin": self._telemetry_origin,
                 },
             )
         )
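For library users, the new argument is simply stored on the instance and forwarded into the PostHog event properties shown above. A minimal usage sketch; the origin string and the speaker reference path are placeholders, and constructing TTS assumes a CUDA machine where the default checkpoint can be fetched:

```python
from fam.llm.fast_inference import TTS

# "notebook" is a placeholder origin label; any string (or None) is accepted.
tts = TTS(telemetry_origin="notebook")

# synthesise() signature per the diff above; the reference clip path is a placeholder.
wav_file = tts.synthesise(
    text="This is a demo of text to speech by MetaVoice-1B.",
    spk_ref_path="assets/bria.mp3",
)
print(wav_file)  # path to the generated .wav
```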

serving.py

Lines changed: 31 additions & 28 deletions
@@ -1,4 +1,3 @@
-import json
 import logging
 import shlex
 import subprocess
@@ -12,7 +11,7 @@
 import tyro
 import uvicorn
 from attr import dataclass
-from fastapi import Request
+from fastapi import File, Form, HTTPException, UploadFile, status
 from fastapi.responses import Response
 
 from fam.llm.fast_inference import TTS
@@ -50,55 +49,55 @@ class _GlobalState:
 GlobalState = _GlobalState()
 
 
-@dataclass(frozen=True)
-class TTSRequest:
-    text: str
-    speaker_ref_path: Optional[str] = None
-    guidance: float = 3.0
-    top_p: float = 0.95
-    top_k: Optional[int] = None
-
-
 @app.get("/health")
 async def health_check():
     return {"status": "ok"}
 
 
 @app.post("/tts", response_class=Response)
-async def text_to_speech(req: Request):
-    audiodata = await req.body()
-    payload = None
+async def text_to_speech(
+    text: str = Form(...),
+    speaker_ref_path: Optional[str] = Form(None),
+    guidance: float = Form(3.0),
+    top_p: float = Form(0.95),
+    audiodata: Optional[UploadFile] = File(None),
+):
+    # Ensure at least one of speaker_ref_path or audiodata is provided
+    if not audiodata and not speaker_ref_path:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Either an audio file or a speaker reference path must be provided.",
+        )
+
     wav_out_path = None
 
     try:
-        headers = req.headers
-        payload = headers["X-Payload"]
-        payload = json.loads(payload)
-        tts_req = TTSRequest(**payload)
         with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp:
-            if tts_req.speaker_ref_path is None:
+            if speaker_ref_path is None:
                 wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp)
                 check_audio_file(wav_path)
             else:
                 # TODO: fix
-                wav_path = tts_req.speaker_ref_path
+                wav_path = speaker_ref_path
 
             if wav_path is None:
                 warnings.warn("Running without speaker reference")
-                assert tts_req.guidance is None
+                assert guidance is None
 
             wav_out_path = GlobalState.tts.synthesise(
-                text=tts_req.text,
+                text=text,
                 spk_ref_path=wav_path,
-                top_p=tts_req.top_p,
-                guidance_scale=tts_req.guidance,
+                top_p=top_p,
+                guidance_scale=guidance,
             )
 
         with open(wav_out_path, "rb") as f:
             return Response(content=f.read(), media_type="audio/wav")
     except Exception as e:
         # traceback_str = "".join(traceback.format_tb(e.__traceback__))
-        logger.exception(f"Error processing request {payload}")
+        logger.exception(
+            f"Error processing request. text: {text}, speaker_ref_path: {speaker_ref_path}, guidance: {guidance}, top_p: {top_p}"
+        )
         return Response(
             content="Something went wrong. Please try again in a few mins or contact us on Discord",
             status_code=500,
@@ -108,9 +107,9 @@ async def text_to_speech(req: Request):
         Path(wav_out_path).unlink(missing_ok=True)
 
 
-def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
+def _convert_audiodata_to_wav_path(audiodata: UploadFile, wav_tmp):
     with tempfile.NamedTemporaryFile() as unknown_format_tmp:
-        if unknown_format_tmp.write(audiodata) == 0:
+        if unknown_format_tmp.write(audiodata.read()) == 0:
             return None
         unknown_format_tmp.flush()
 
@@ -129,7 +128,11 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
     logging.root.setLevel(logging.INFO)
 
     GlobalState.config = tyro.cli(ServingConfig)
-    GlobalState.tts = TTS(seed=GlobalState.config.seed, quantisation_mode=GlobalState.config.quantisation_mode)
+    GlobalState.tts = TTS(
+        seed=GlobalState.config.seed,
+        quantisation_mode=GlobalState.config.quantisation_mode,
+        telemetry_origin="api_server",
+    )
 
     app.add_middleware(
         fastapi.middleware.cors.CORSMiddleware,
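With the fixed POST definition, clients send standard multipart form fields (and optionally an audio file) instead of a JSON blob in an X-Payload header. A hypothetical client sketch; the host and port are assumptions (use whatever ServingConfig exposes), and the file names are placeholders:

```python
import requests

# Placeholder address; point this at wherever serving.py is actually running.
SERVER = "http://localhost:58003"

# Upload a reference clip as multipart form data ("reference.wav" is a placeholder).
with open("reference.wav", "rb") as f:
    resp = requests.post(
        f"{SERVER}/tts",
        data={"text": "Hello from MetaVoice.", "guidance": 3.0, "top_p": 0.95},
        files={"audiodata": ("reference.wav", f, "audio/wav")},
    )

resp.raise_for_status()
with open("output.wav", "wb") as out:
    out.write(resp.content)  # the endpoint returns raw audio/wav bytes
```

If the reference audio already lives on the server's filesystem, send speaker_ref_path as a form field instead of uploading a file; per the diff above, the endpoint now returns HTTP 400 when neither is provided.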
