Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
zachwe committed Feb 21, 2024
2 parents 04f74f6 + 795b453 commit dbabedb
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 93 deletions.
56 changes: 47 additions & 9 deletions clients/simple/simple_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import threading
import queue
from pathlib import Path
import asyncio
import websockets

import numpy as np
import sounddevice as sd
Expand All @@ -31,7 +33,7 @@
from pydub import AudioSegment
from uuid import uuid4

SAMPLE_RATE = 22050
SAMPLE_RATE = 24000
CHANNELS = 1

IDLE = "IDLE"
Expand All @@ -40,7 +42,7 @@
PLAYBACK = "PLAYBACK"
RECORDING_FILE = "recording.wav"
RESPONSE_FILE = "response.wav"
UBERDUCK_API = os.environ["UBERDUCK_API"]
UBERDUCK_API_HOST = os.environ["UBERDUCK_API_HOST"]

speech_file_path = Path(__file__).parent / "response.wav"
chat_history = [
Expand All @@ -66,12 +68,47 @@
session = str(uuid4())


async def uberduck_websocket():
    """Send the recorded audio over a websocket and play the streamed reply.

    Connects to ``ws://{UBERDUCK_API_HOST}?session_id={session}``, sends the
    raw bytes of ``RECORDING_FILE`` as a single binary message, then plays
    every binary message received as 16-bit PCM at ``SAMPLE_RATE``. The loop
    ends when the server closes the connection or sends the text message
    ``"done"``.
    """
    # The server endpoint takes session_id as a query parameter, so it has to
    # be part of the websocket URI.
    uri = f"ws://{UBERDUCK_API_HOST}?session_id={session}"
    print(uri)
    loop = asyncio.get_running_loop()
    async with websockets.connect(uri) as websocket:
        print("[INFO] Sending audio to the server...")
        with open(RECORDING_FILE, "rb") as file:
            audio_content = file.read()
        await websocket.send(audio_content)
        print("[INFO] Audio sent to the server.")

        async for message in websocket:
            if isinstance(message, str):
                # Text frames are control messages, not PCM audio; feeding
                # them to np.frombuffer would raise.
                if message == "done":
                    break
                continue
            data = np.frombuffer(message, dtype=np.int16)
            print("[INFO] Playing received audio.")
            sd.play(data, SAMPLE_RATE)
            # sd.wait() is a blocking call, not an awaitable. Run it in the
            # default executor so the event loop stays responsive (websocket
            # keepalive) while the chunk plays.
            await loop.run_in_executor(None, sd.wait)


def uberduck_response():
uri = "http://" + UBERDUCK_API_HOST
with open(RECORDING_FILE, "rb") as file:
print(f"[INFO] Sending audio to the server...")
files = {"audio": (RECORDING_FILE, file, "audio/wav")}
payload = {"session_id": session}
response = requests.post(UBERDUCK_API, files=files, data=payload)
response = requests.post(uri, files=files, data=payload)
print(f"[INFO] Response received from the server: {response.status_code}")
if response.status_code == 200:
data = np.frombuffer(response.content, dtype=np.int16)
Expand Down Expand Up @@ -154,10 +191,10 @@ def play(self):
sd.wait()
print("[INFO] Playback finished. Press space to start recording.")

async def start_processing(self):
    """Run one request/response round trip for the current recording.

    Uses the Uberduck websocket client when ``USE_UBERDUCK`` is truthy and
    ``openai_response()`` otherwise.
    """
    print("[INFO] Processing...")
    if USE_UBERDUCK:
        # uberduck_websocket is a coroutine function: without ``await`` the
        # call only creates a coroutine object and nothing is sent or played.
        await uberduck_websocket()
    else:
        openai_response()
    print("[INFO] Processing finished.")
Expand All @@ -174,7 +211,7 @@ def set_state(self, state):
def __str__(self) -> str:
return f"State: {self.state}"

def on_press(self, key):
async def on_press(self, key):
print("key: ", key)
if key == "space":
if self.state == IDLE:
Expand All @@ -184,16 +221,17 @@ def on_press(self, key):
elif self.state == RECORDING:
self.recorder.stop_recording()
self.set_state(PROCESSING)
self.recorder.start_processing()
self.set_state(PLAYBACK)
self.recorder.play()
await self.recorder.start_processing()
self.set_state(IDLE)

def run(self):
    """Start the keyboard listener with ``self.on_press`` as the key-press callback."""
    press_handler = self.on_press
    listen_keyboard(on_press=press_handler)


if __name__ == "__main__":
    # Resolve the startup sound next to this script (the same way
    # speech_file_path is built) so launching from another working directory
    # still finds it.
    startup_path = Path(__file__).parent / "startup.wav"
    startup_sound, fs = sf.read(str(startup_path))
    sd.play(startup_sound, fs)
    sd.wait()  # let the greeting finish before accepting key presses
    print("Press space to start recording.")
    sm = StateMachine()
    sm.run()
Binary file added clients/simple/startup.wav
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""add_chat_history
Revision ID: 42f7dfcde186
Revises: 00cad591e71a
Create Date: 2024-02-20 20:16:43.082215+00:00
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import sqlite

# revision identifiers, used by Alembic.
revision: str = '42f7dfcde186'
down_revision: Union[str, None] = '00cad591e71a'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the chat_history, template_deployment and template_prompt tables.

    Each table gets a nullable-or-required ``user_id`` foreign key to
    ``user.id`` with ``ON DELETE CASCADE`` plus the indexes listed below.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # chat_history: one row per stored conversation, keyed by session_id text.
    op.create_table('chat_history',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=True),
        sa.Column('session_id', sa.Text(), nullable=False),
        sa.Column('history_json', sqlite.JSON(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_chat_history_user_id'), 'chat_history', ['user_id'], unique=False)
    # template_deployment: uuid uniqueness is enforced by a table constraint here.
    op.create_table('template_deployment',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('uuid', sa.Text(), nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=False),
        sa.Column('url_name', sa.String(), nullable=False),
        sa.Column('display_name', sa.String(), nullable=True),
        sa.Column('prompt', sqlite.JSON(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=False),
        sa.Column('deleted_at', sa.DateTime(), nullable=True),
        sa.Column('meta_json', sqlite.JSON(), nullable=True),
        sa.Column('model', sa.String(), nullable=True),
        sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('uuid')
    )
    # Partial unique index: (user_id, url_name) must be unique only among rows
    # where deleted_at IS NULL.
    op.create_index('deployment_user_id_url_name_unique_not_deleted', 'template_deployment', ['user_id', 'url_name'], unique=True, sqlite_where=sa.text('deleted_at IS NULL'))
    op.create_index(op.f('ix_template_deployment_id'), 'template_deployment', ['id'], unique=False)
    op.create_index(op.f('ix_template_deployment_url_name'), 'template_deployment', ['url_name'], unique=False)
    op.create_index(op.f('ix_template_deployment_user_id'), 'template_deployment', ['user_id'], unique=False)
    # template_prompt: same columns as template_deployment; uuid uniqueness is
    # enforced by the unique index ix_template_prompt_uuid instead of a constraint.
    op.create_table('template_prompt',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('uuid', sa.Text(), nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=False),
        sa.Column('url_name', sa.String(), nullable=False),
        sa.Column('display_name', sa.String(), nullable=True),
        sa.Column('prompt', sqlite.JSON(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=False),
        sa.Column('deleted_at', sa.DateTime(), nullable=True),
        sa.Column('meta_json', sqlite.JSON(), nullable=True),
        sa.Column('model', sa.String(), nullable=True),
        sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_template_prompt_id'), 'template_prompt', ['id'], unique=False)
    op.create_index(op.f('ix_template_prompt_url_name'), 'template_prompt', ['url_name'], unique=False)
    op.create_index(op.f('ix_template_prompt_user_id'), 'template_prompt', ['user_id'], unique=False)
    op.create_index(op.f('ix_template_prompt_uuid'), 'template_prompt', ['uuid'], unique=True)
    # Partial unique index, same rule as for template_deployment.
    op.create_index('prompt_user_id_url_name_unique_not_deleted', 'template_prompt', ['user_id', 'url_name'], unique=True, sqlite_where=sa.text('deleted_at IS NULL'))
    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop the three tables created by ``upgrade`` together with their indexes.

    Tables are removed in the reverse of their creation order
    (template_prompt, template_deployment, chat_history); for each table the
    indexes are dropped before the table itself.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # template_prompt
    op.drop_index('prompt_user_id_url_name_unique_not_deleted', table_name='template_prompt', sqlite_where=sa.text('deleted_at IS NULL'))
    op.drop_index(op.f('ix_template_prompt_uuid'), table_name='template_prompt')
    op.drop_index(op.f('ix_template_prompt_user_id'), table_name='template_prompt')
    op.drop_index(op.f('ix_template_prompt_url_name'), table_name='template_prompt')
    op.drop_index(op.f('ix_template_prompt_id'), table_name='template_prompt')
    op.drop_table('template_prompt')
    # template_deployment
    op.drop_index(op.f('ix_template_deployment_user_id'), table_name='template_deployment')
    op.drop_index(op.f('ix_template_deployment_url_name'), table_name='template_deployment')
    op.drop_index(op.f('ix_template_deployment_id'), table_name='template_deployment')
    op.drop_index('deployment_user_id_url_name_unique_not_deleted', table_name='template_deployment', sqlite_where=sa.text('deleted_at IS NULL'))
    op.drop_table('template_deployment')
    # chat_history
    op.drop_index(op.f('ix_chat_history_user_id'), table_name='chat_history')
    op.drop_table('chat_history')
    # ### end Alembic commands ###
4 changes: 1 addition & 3 deletions openduck-py/openduck_py/routers/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from fastapi import FastAPI
from openduck_py.routers.voice import voice_router
from fastapi import FastAPI, WebSocket
from openduck_py.routers.templates import templates_router
from openduck_py.routers.voice import audio_router

Expand All @@ -17,7 +16,6 @@
debug=IS_DEV,
)

app.include_router(voice_router)
app.include_router(templates_router)
app.include_router(audio_router)

Expand Down
102 changes: 31 additions & 71 deletions openduck-py/openduck_py/routers/voice.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,48 @@
import io
import re
from tempfile import NamedTemporaryFile
from uuid import uuid4
from fastapi import APIRouter, Depends, UploadFile, File, Form
from fastapi import APIRouter, Depends, Query, WebSocket
from sqlalchemy import select
from starlette.responses import StreamingResponse
import whisper
import base64
from time import time

from scipy.io.wavfile import read, write
import numpy as np

from openduck_py.utils.third_party_tts import aio_polly_tts
from openduck_py.utils.s3 import download_file
from openduck_py.models import DBVoice, DBUser, DBChatHistory
from openduck_py.models import DBChatHistory
from openduck_py.db import get_db_async, AsyncSession
from openduck_py.voices import styletts2
from pydantic import BaseModel
from openduck_py.routers.templates import generate


voice_router = APIRouter(prefix="/voice")


@voice_router.post("/text-to-speech", include_in_schema=False)
async def text_to_speech(
db: AsyncSession = Depends(get_db_async),
):

raise NotImplementedError

styletts2.styletts2_inference(
text="Hello, my name is Matthew. How are you today?",
model_path="styletts2/rap_v1.pt",
model_bucket="uberduck-models-us-west-2",
config_path="styletts2/rap_v1_config.yml",
config_bucket="uberduck-models-us-west-2",
output_bucket="uberduck-audio-outputs",
output_path="test.wav",
style_prompt_path="511f17d1-8a30-4be8-86aa-4cdd8b0aed70.wav",
style_prompt_bucket="uberduck-audio-files",
)

voice_uuid = "906471f3-efa1-4410-978e-c105ac4fad61"
voice = await db.execute(
select(DBVoice).where(DBVoice.voice_uuid == voice_uuid).limit(1)
)
request_id = str(uuid4())
upload_path = f"{request_id}/output.mp3"
text = "Il était une fois, dans un petit village pittoresque en France, deux âmes solitaires dont les chemins étaient destinés à se croiser. Juliette, une jeune fleuriste passionnée par les couleurs et les parfums de ses fleurs, passait ses journées à embellir la vie des villageois avec ses bouquets enchanteurs. De l'autre côté du village vivait Étienne, un poète timide dont les vers capturaient la beauté et la mélancolie de la vie, mais qui gardait ses poèmes pour lui, craignant qu'ils ne soient pas à la hauteur du monde extérieur."
await aio_polly_tts(
text=text,
voice_id="Mathieu",
language_code="fr-FR",
engine="standard",
upload_path=upload_path,
output_format="mp3",
)
return dict(
uuid=request_id,
path=f"https://uberduck-audio-outputs.s3-us-west-2.amazonaws.com/{upload_path}",
)


model = whisper.load_model("tiny") # Fastest possible whisper model

audio_router = APIRouter(prefix="/audio")


@audio_router.post("/response", include_in_schema=False)
@audio_router.websocket("/response")
async def audio_response(
session_id: str = Form(None),
audio: UploadFile = File(None),
websocket: WebSocket,
session_id: str,
db: AsyncSession = Depends(get_db_async),
response_class=StreamingResponse,
):
await websocket.accept()

print("Session ID", session_id)
audio_data = await websocket.receive_bytes()
assert session_id is not None
t0 = time()

with NamedTemporaryFile() as temp_file:
data = await audio.read()
temp_file.write(data)
transcription = model.transcribe(temp_file.name)["text"]
def _transcribe():
    """Write the received audio to a temp file and return Whisper's ``"text"`` result."""
    with NamedTemporaryFile() as temp_file:
        temp_file.write(audio_data)
        # The model is handed the file *name* and reads it independently of
        # this handle, so buffered bytes must reach disk first; otherwise the
        # tail of the audio can be missing.
        temp_file.flush()
        transcription = model.transcribe(temp_file.name)["text"]
    return transcription

from asgiref.sync import sync_to_async

_async_transcribe = sync_to_async(_transcribe)
transcription = await _async_transcribe()

if not transcription:
return

t_whisper = time()

Expand Down Expand Up @@ -117,20 +77,20 @@ async def audio_response(
chat.history_json["messages"] = messages
await db.commit()

audio_chunks = []
sentences = re.split(r"(?<=[.!?]) +", response_message.content)
for i in range(0, len(sentences), 2):
chunk_text = " ".join(sentences[i : i + 2])
audio_chunk = styletts2.styletts2_inference(text=chunk_text)
audio_chunks.append(audio_chunk)
audio = np.concatenate(audio_chunks)
audio = np.int16(audio * 32767) # Scale to 16-bit integer values
for sentence in sentences:
# TODO: deal with asyncio
audio_chunk = styletts2.styletts2_inference(text=sentence)
audio_chunk_bytes = np.int16(
audio_chunk * 32767
).tobytes() # Scale to 16-bit integer values
await websocket.send_bytes(audio_chunk_bytes)

t_styletts = time()

print("Whisper", t_whisper - t0)
print("GPT", t_gpt - t_whisper)
print("StyleTTS2", t_styletts - t_gpt)

output = StreamingResponse(io.BytesIO(audio), media_type="application/octet-stream")
return output
# await websocket.send_text("done")
await websocket.close()
16 changes: 6 additions & 10 deletions openduck-py/openduck_py/scripts/run_styletts2.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import soundfile as sf
from openduck_py.voices import styletts2
from openduck_py.voices.settings import SAMPLE_RATE


styletts2.styletts2_inference(
text="Hello, my name is Matthew. How are you today?",
model_path="styletts2/rap_v1.pt",
model_bucket="uberduck-models-us-west-2",
config_path="styletts2/rap_v1_config.yml",
config_bucket="uberduck-models-us-west-2",
output_bucket="uberduck-audio-outputs",
output_path="test.wav",
style_prompt_path="511f17d1-8a30-4be8-86aa-4cdd8b0aed70.wav",
style_prompt_bucket="uberduck-audio-files",
# Synthesize the greeting, passing only the text to styletts2_inference.
audio = styletts2.styletts2_inference(
    text="Hey, I'm the Uberduck! What do you want to learn about today?"
)

# Written at the SAMPLE_RATE imported from openduck_py.voices.settings.
# NOTE(review): an earlier remark here assumed 22050 Hz; the simple client now
# plays audio at 24000 Hz — confirm the value in voices.settings.
sf.write("startup.wav", audio, SAMPLE_RATE)
1 change: 1 addition & 0 deletions openduck-py/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
alembic
aiosqlite
aioboto3
asgiref
asyncpg
azure-cognitiveservices-speech
databases
Expand Down

0 comments on commit dbabedb

Please sign in to comment.