From 59ef28d70833913cd1163b14b4ef373b0e2725a5 Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Wed, 26 Jun 2024 23:53:17 +0200
Subject: [PATCH] build: move umap-learn into optional notebook dependencies

Except for notebooks, it's only used to show embedding plots during
speaker encoder training, in which case a warning is now shown to
install it.
---
 TTS/bin/train_encoder.py    | 14 +++++++++-----
 TTS/encoder/utils/visual.py |  5 ++++-
 pyproject.toml              |  3 +--
 3 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index 49b450cf82..ba03c42b6d 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -6,6 +6,7 @@
 import sys
 import time
 import traceback
+import warnings
 
 import torch
 from torch.utils.data import DataLoader
@@ -116,11 +117,14 @@ def evaluation(model, criterion, data_loader, global_step):
     eval_avg_loss = eval_loss / len(data_loader)
     # save stats
     dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss})
-    # plot the last batch in the evaluation
-    figures = {
-        "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
-    }
-    dashboard_logger.eval_figures(global_step, figures)
+    try:
+        # plot the last batch in the evaluation
+        figures = {
+            "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
+        }
+        dashboard_logger.eval_figures(global_step, figures)
+    except ImportError:
+        warnings.warn("Install the `umap-learn` package to see embedding plots.")
     return eval_avg_loss
 
 
diff --git a/TTS/encoder/utils/visual.py b/TTS/encoder/utils/visual.py
index 6575b86ec2..bfe40605df 100644
--- a/TTS/encoder/utils/visual.py
+++ b/TTS/encoder/utils/visual.py
@@ -1,7 +1,6 @@
 import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
-import umap
 
 matplotlib.use("Agg")
 
@@ -30,6 +29,10 @@
 
 
 def plot_embeddings(embeddings, num_classes_in_batch):
+    try:
+        import umap
+    except ImportError as e:
+        raise ImportError("Package not installed: umap-learn") from e
     num_utter_per_class = embeddings.shape[0] // num_classes_in_batch
 
     # if necessary get just the first 10 classes
diff --git a/pyproject.toml b/pyproject.toml
index dad0d5ed0d..93486ff03a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,8 +58,6 @@ dependencies = [
     "packaging>=23.1",
     # Inference
     "pysbd>=0.3.4",
-    # Notebooks
-    "umap-learn>=0.5.1",
     # Training
     "matplotlib>=3.7.0",
     # Coqui stack
@@ -100,6 +98,7 @@ docs = [
 notebooks = [
     "bokeh==1.4.0",
     "pandas>=1.4,<2.0",
+    "umap-learn>=0.5.1",
 ]
 # For running the TTS server
 server = ["flask>=2.0.1"]