Skip to content

Commit

Permalink
load models from disk instead of S3
Browse files Browse the repository at this point in the history
This really speeds up the server startup time
  • Loading branch information
matthewkennedy5 committed Feb 21, 2024
1 parent e69d19c commit d90c861
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 50 deletions.
3 changes: 1 addition & 2 deletions openduck-py/openduck_py/routers/voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from openduck_py.voices import styletts2
from openduck_py.routers.templates import generate

model = whisper.load_model("tiny") # Fastest possible whisper model
model = whisper.load_model("medium")  # Medium model: better accuracy, but slower than "tiny"

audio_router = APIRouter(prefix="/audio")

Expand Down Expand Up @@ -92,5 +92,4 @@ def _transcribe():
print("GPT", t_gpt - t_whisper)
print("StyleTTS2", t_styletts - t_gpt)

# await websocket.send_text("done")
await websocket.close()
62 changes: 14 additions & 48 deletions openduck-py/openduck_py/voices/styletts2.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def load_params(path):
return params_whole["net"]


def _load_model(
model_bucket, model_path, text_aligner, pitch_extractor, plbert, model_params
):
def load_model(model_path, text_aligner, pitch_extractor, plbert, model_params):
# NOTE (Sam): building the model prior to loading makes using the "model_args" key to store the config not work.
model = build_model(
recursive_munch(model_params),
Expand All @@ -80,9 +78,12 @@ def _load_model(

_ = [model[key].to(DEVICE) for key in model]

state_dict = load_object_from_s3(
s3_key=model_path, s3_bucket=model_bucket, loader=load_params
)
# state_dict = load_object_from_s3(
# s3_key=model_path, s3_bucket=model_bucket, loader=load_params
# )

state_dict = load_params(model_path)

for key in model:
if key in state_dict:
print("%s loaded" % key)
Expand Down Expand Up @@ -111,24 +112,6 @@ def _load_model(
return model, sampler


# NOTE (Sam): this is a version of the loading code from celery-bark that caches models.
def load_model(
cache, model_bucket, model_path, text_aligner, pitch_extractor, plbert, model_params
):
key = model_path.replace("/", "_")
if key not in cache:
cache[key] = _load_model(
model_bucket,
model_path,
text_aligner,
pitch_extractor,
plbert,
model_params,
)

return cache[key]


def load_phonemizer(language, cache):
if language not in cache:
cache[language] = phonemizer.backend.EspeakBackend(
Expand Down Expand Up @@ -274,32 +257,15 @@ def resize_array(input_array, new_size):
s3_key=config_path, s3_bucket=config_bucket, loader=load_config
)["model_params"]
cache = pylru.lrucache(1)
asr_config = load_object_from_s3(
s3_key="styletts2/asr/config.yml", s3_bucket=MODEL_BUCKET, loader=load_config
)
text_aligner = load_object_from_s3(
s3_key="styletts2/asr/epoch_00080.pth",
s3_bucket=MODEL_BUCKET,
loader=lambda path: load_asr_models(path, asr_config),
)
pitch_extractor = load_object_from_s3(
s3_key="styletts2/jdc/bst.t7", s3_bucket=MODEL_BUCKET, loader=load_f0_models
)
plbert_config = load_object_from_s3(
s3_key="styletts2/plbert/config.yml", s3_bucket=MODEL_BUCKET, loader=load_config
)
plbert = load_object_from_s3(
s3_key="styletts2/plbert/step_1000000.t7",
s3_bucket=MODEL_BUCKET,
loader=lambda x: load_plbert(plbert_config, x),
)

model_path = "styletts2/prototype_voice.pth"
model_bucket = "uberduck-models-us-west-2"
asr_config = load_config("models/asr_config.yml")
plbert_config = load_config("models/plbert_config.yml")

text_aligner = load_asr_models("models/text_aligner.pth", asr_config)
pitch_extractor = load_f0_models("models/pitch_extractor.t7")
plbert = load_plbert(plbert_config, "models/plbert.t7")
model, sampler = load_model(
cache=cache,
model_bucket=model_bucket,
model_path=model_path,
model_path="models/prototype_voice.pth",
text_aligner=text_aligner,
pitch_extractor=pitch_extractor,
plbert=plbert,
Expand Down

0 comments on commit d90c861

Please sign in to comment.