Move the normalizer to the ML API #77

Merged · 4 commits · Mar 19, 2024
docker-compose.yml (4 changes: 2 additions & 2 deletions)

```diff
@@ -37,12 +37,12 @@ services:
       bash -c "python setup.py develop && \
       mkdir -p models/styletts2 && \
       aws s3 sync s3://uberduck-models-us-west-2/prototype/styletts2 models/styletts2 && \
-      uvicorn openduck_py.routers.ml:app --reload --host 0.0.0.0 --port 8001"
+      uvicorn openduck_py.routers.main:app --reload --host 0.0.0.0 --port 8000"
     working_dir: /openduck-py/openduck-py
     volumes:
       - .:/openduck-py
     ports:
-      - "8001:8001"
+      - "8000:8000"
    env_file:
      - .env.dev
    runtime: nvidia
```
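The compose service now runs the main app on port 8000, so the ML endpoints are reached through `ML_API_URL` (consumed in `voice.py` below) rather than being co-hosted. A quick smoke test after `docker compose up`, as a sketch that is not part of the PR:

```python
# Illustrative smoke test: assumes the main API on localhost:8000 per this
# compose file, and reads ML_API_URL rather than guessing where the ML
# service listens.
import os

import httpx

ml_api_url = os.environ["ML_API_URL"]

# The /ml/normalize endpoint added in this PR; in dev mode it echoes the
# input text back unchanged.
resp = httpx.post(f"{ml_api_url}/ml/normalize", json={"text": "Meet Dr. Smith at 3:30pm"})
resp.raise_for_status()
print("ml ok:", resp.json()["text"])

# FastAPI serves interactive docs at /docs by default, so a 200 here means
# the main API is up.
resp = httpx.get("http://localhost:8000/docs")
print("main ok:", resp.status_code == 200)
```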
openduck-py/openduck_py/routers/ml.py (29 changes: 21 additions & 8 deletions)

```diff
@@ -1,19 +1,36 @@
-from fastapi import FastAPI, APIRouter, UploadFile, File, HTTPException
+from fastapi import APIRouter, UploadFile, File, HTTPException
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 import io
 
 from whisper import load_model
 import numpy as np
+from nemo_text_processing.text_normalization.normalize import Normalizer
 
 from openduck_py.voices.styletts2 import styletts2_inference
-from openduck_py.settings import OUTPUT_SAMPLE_RATE
+from openduck_py.settings import OUTPUT_SAMPLE_RATE, IS_DEV
 
 ml_router = APIRouter(prefix="/ml")
 
 whisper_model = load_model("base.en")
 
+# TODO (Matthew): Load the normalizer on IS_DEV but change the docker-compose to only reload the ML
+# service if this file is changed
+if IS_DEV:
+    normalize_text_fn = lambda x: x
+else:
+    normalizer = Normalizer(input_case="cased", lang="en")
+    normalize_text_fn = normalizer.normalize
+
+
+class TextInput(BaseModel):
+    text: str
+
+
+@ml_router.post("/normalize")
+async def normalize_text(text: TextInput):
+    return {"text": normalize_text_fn(text.text)}
+
+
 @ml_router.post("/transcribe")
 async def transcribe_audio(
@@ -32,12 +49,8 @@ async def transcribe_audio(
         raise HTTPException(status_code=500, detail=str(e))
 
 
-class TTSInput(BaseModel):
-    text: str
-
-
 @ml_router.post("/tts")
-async def text_to_speech(tts_input: TTSInput):
+async def text_to_speech(tts_input: TextInput):
     try:
         audio_chunk = styletts2_inference(
             text=tts_input.text,
```
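For context on what `/normalize` returns outside of dev: NeMo's `Normalizer` rewrites written-form text into spoken form ahead of TTS. A standalone sketch using the same configuration as above (output wording is illustrative and varies by NeMo version):

```python
from nemo_text_processing.text_normalization.normalize import Normalizer

# Same configuration the router uses in production mode.
normalizer = Normalizer(input_case="cased", lang="en")

# Currency, dates, and times get expanded into words for the TTS model:
print(normalizer.normalize("The demo costs $5 and ships on Mar. 19, 2024."))
# Roughly: "The demo costs five dollars and ships on march nineteenth
# twenty twenty four." (exact output depends on the NeMo version)
```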
openduck-py/openduck_py/routers/voice.py (27 changes: 13 additions & 14 deletions)

```diff
@@ -3,7 +3,7 @@
 import re
 import multiprocessing
 from time import time
-from typing import Optional, Dict, Literal, AsyncIterator, AsyncGenerator
+from typing import Optional, Dict, Literal, AsyncGenerator
 import wave
 import requests
 from pathlib import Path
@@ -27,7 +27,6 @@
 from openduck_py.settings import (
     CHAT_MODEL,
     CHUNK_SIZE,
-    IS_DEV,
     LOG_TO_SLACK,
     ML_API_URL,
     OUTPUT_SAMPLE_RATE,
@@ -38,14 +37,6 @@
 from openduck_py.utils.third_party_tts import aio_elevenlabs_tts
 from openduck_py.logging.slack import log_audio_to_slack
 
-if IS_DEV:
-    normalize_text = lambda x: x
-else:
-    from nemo_text_processing.text_normalization.normalize import Normalizer
-
-    normalizer = Normalizer(input_case="cased", lang="en")
-    normalize_text = normalizer.normalize
-
 
 try:
     pipeline, inference = load_pipelines()
@@ -60,7 +51,6 @@
 )
 
 speaker_embedding = inference("aec-cartoon-degraded.wav")
-
 audio_router = APIRouter(prefix="/audio")
 
 Daily.init()
@@ -79,7 +69,6 @@ async def _transcribe(audio_data: np.ndarray) -> str:
     async with httpx.AsyncClient() as client:
         response = await client.post(url, files=files)
 
-    # Check the response status code
     if response.status_code == 200:
         return response.json()["text"]
     else:
@@ -95,6 +84,17 @@ async def _inference(sentence: str) -> AsyncGenerator[bytes, None]:
         yield chunk
 
 
+async def _normalize_text(text: str) -> str:
+    url = f"{ML_API_URL}/ml/normalize"
+    async with httpx.AsyncClient() as client:
+        response = await client.post(url, json={"text": text})
+
+    if response.status_code == 200:
+        return response.json()["text"]
+    else:
+        raise Exception(f"Normalization failed with status code {response.status_code}")
+
+
 class WavAppender:
     def __init__(self, wav_file_path="output.wav"):
         self.wav_file_path = wav_file_path
@@ -325,8 +325,7 @@ async def speak_response(
         return
 
         if self.tts_config.provider == "local":
-
-            normalized = normalize_text(response_text)
+            normalized = await _normalize_text(response_text)
            t_normalize = time()
            await log_event(
                db,
```
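Since normalization is now an HTTP call on the TTS hot path, failures change shape: instead of an in-process function call, `_normalize_text` can time out (httpx applies a 5-second default) or see non-2xx responses other than outright failure. A variant making the timeout explicit and covering all non-2xx codes, as a sketch rather than a suggested change to the diff:

```python
import httpx

from openduck_py.settings import ML_API_URL


async def _normalize_text(text: str) -> str:
    # Explicit budgets: fail the connect fast and bound the whole call, so a
    # stuck ML service degrades TTS latency predictably instead of hanging.
    timeout = httpx.Timeout(5.0, connect=2.0)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(f"{ML_API_URL}/ml/normalize", json={"text": text})
    # Covers every non-2xx status, not only non-200.
    response.raise_for_status()
    return response.json()["text"]
```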
openduck-py/openduck_py/settings/__init__.py (3 changes: 1 addition & 2 deletions)

```diff
@@ -4,8 +4,7 @@
 EMB_MATCH_THRESHOLD = 0.5
 WS_SAMPLE_RATE = 16_000
 OUTPUT_SAMPLE_RATE = 24_000
-DEPLOY_ENV = os.environ.get("DEPLOY_ENV", "dev")
-IS_DEV = DEPLOY_ENV == "dev"
+IS_DEV = bool(os.environ["IS_DEV"])
 ML_API_URL = os.environ["ML_API_URL"]
 # Set to 1024 for the esp32, but larger CHUNK_SIZE is needed to prevent choppiness with the local client
 CHUNK_SIZE = 10240
```
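One thing to flag on the `IS_DEV` change: `bool()` of a string is `True` for any non-empty value, so `IS_DEV=false` or `IS_DEV=0` in the environment still counts as dev mode, and an unset variable raises `KeyError` at import time. A stricter parse, offered as a sketch rather than what the PR does:

```python
import os

# Only an explicit opt-in enables dev mode; unset, "0", "false", etc. all
# resolve to production.
IS_DEV = os.environ.get("IS_DEV", "").strip().lower() in ("1", "true", "yes")
```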