# main.py (forked from norbertkross/fastapi-openvoice-tts)
import os
from typing import Union, Optional

from src.generate_audio import AudioGenerator
from fastapi import FastAPI, Response
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

app = FastAPI()

# Allow all origins for demonstration purposes
# You should restrict this in a production environment
origins = [
    "*",
    "http://localhost",
    "http://localhost:3000",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class AudioRequest(BaseModel):
    text: str
    voice: Optional[str] = Field('testimonial', description="The voice tone to use, from the available voices: 'testimonial' | 'female_1'")
    outputFormat: Optional[str] = Field('mp3', description="The generated audio encoding, supports 'raw' | 'mp3' | 'wav' | 'ogg' | 'flac' | 'mulaw'")
    speed: Optional[float] = Field(1.0, description="Playback rate of the generated speech")
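
# A minimal example request body for /generate (the text value is illustrative,
# the remaining fields fall back to the defaults declared above):
#
# {
#     "text": "Hello from OpenVoice",
#     "voice": "female_1",
#     "outputFormat": "wav",
#     "speed": 1.0
# }
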
# Global variables to store the model and speaker IDs
target_se = None
model = None
source_se = None
speaker_id = None
tone_color_converter = None
target_se_default_male = None
target_se_default_female = None

# Initialize the TTS model, speaker embeddings, and tone color converter
def globalDataSetter():
    audioGenerator = AudioGenerator()
    target_se, tone_color_converter, target_se_default_male, target_se_default_female = audioGenerator.targetSEreference()
    model, source_se, speaker_id = audioGenerator.generatorModelsAndParamsInitializer()
    return target_se, model, source_se, speaker_id, tone_color_converter, target_se_default_male, target_se_default_female


@app.on_event("startup")
def startup_event():
    global target_se, model, source_se, speaker_id, tone_color_converter, target_se_default_male, target_se_default_female
    target_se, model, source_se, speaker_id, tone_color_converter, target_se_default_male, target_se_default_female = globalDataSetter()
    print("Startup complete: TTS model and speaker embeddings loaded")


@app.get("/")
def read_root():
    return "HELLO WORLD"


@app.post("/generate")
async def generate(request: AudioRequest):
    audioGenerator = AudioGenerator()
    text = request.text

    # Pick the speaker embedding that matches the requested voice
    targetSEtoUse = audioGenerator.determineVoiceToUse(request.voice, target_se_default_male, target_se_default_female)
    file_path = await audioGenerator.generateAudio(text, targetSEtoUse, model, source_se, speaker_id, tone_color_converter, request.speed, request.outputFormat)

    if not os.path.exists(file_path):
        print("File not found")
        return Response(content="File not found", status_code=404)

    async def file_streamer():
        try:
            async for chunk in audioGenerator.async_iterator(file_path):
                yield chunk
        finally:
            # Once streaming finishes, delete both the converted output and its
            # intermediate companion file (same path without the "output_v2_" prefix)
            pathToDelete = file_path.replace("output_v2_", "")
            print(f'Deleting intermediate file: {pathToDelete}')
            await audioGenerator.delete_file(pathToDelete)
            await audioGenerator.delete_file(file_path)

    # Return the audio file as a byte stream using StreamingResponse
    return StreamingResponse(content=file_streamer(), media_type=f'audio/{request.outputFormat}')
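

# Example client call (a sketch, not part of the app itself): it assumes the
# server runs on http://localhost:8000 and that the `requests` package is
# installed, then streams the generated audio to a local file.
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:8000/generate",
#         json={"text": "Hello from OpenVoice", "voice": "testimonial", "outputFormat": "mp3"},
#         stream=True,
#     )
#     with open("output.mp3", "wb") as f:
#         for chunk in resp.iter_content(chunk_size=8192):
#             f.write(chunk)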