Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows fixes #84

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 138 additions & 63 deletions fam/llm/enhancers.py
Original file line number Diff line number Diff line change
@@ -1,102 +1,177 @@
import os
import logging
from abc import ABC
from typing import Literal, Optional

from df.enhance import enhance, init_df, load_audio, save_audio
from pydub import AudioSegment

# Configure basic logging settings for the application.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def convert_to_wav(input_file: str, output_file: str):
"""Convert an audio file to WAV format
"""
Converts an audio file to WAV format.

Args:
input_file (str): path to input audio file
output_file (str): path to output WAV file
input_file (str): Path to the input audio file.
output_file (str): Path where the output WAV file will be saved.

This function uses pydub.AudioSegment to read an audio file in its original format and export it as a WAV file.
"""
# Detect the format of the input file
format = input_file.split(".")[-1].lower()

# Read the audio file
audio = AudioSegment.from_file(input_file, format=format)

# Export as WAV
audio.export(output_file, format="wav")

try:
logger.info("Starting convert_to_wav")
# Extract the file format from the input file name.
format = input_file.split(".")[-1].lower()
# Load the audio file using its format for proper decoding.
audio = AudioSegment.from_file(input_file, format=format)
# Export the audio data to a new file in WAV format.
audio.export(output_file, format="wav")
logger.info("Finished convert_to_wav")
except Exception as e:
logger.error(f"Error in convert_to_wav: {e}")

def make_output_file_path(audio_file: str, tag: str, ext: Optional[str] = None) -> str:
"""Generate the output file path
"""
Generates a path for the output file with an added tag and optional custom extension.

Args:
audio_file (str): path to input audio file
tag (str): tag to append to the output file name
ext (str, optional): extension of the output file. Defaults to None.
audio_file (str): Original path of the audio file.
tag (str): Tag to append to the filename (before the extension).
ext (Optional[str]): Optional custom extension for the output file. Uses original extension if None.

Returns:
str: path to output file
str: Path for the output file with the specified tag and extension.
"""

directory = "./enhanced"
# Get the name of the input file
filename = os.path.basename(audio_file)

# Get the name of the input file without the extension
filename_without_extension = os.path.splitext(filename)[0]

# Get the extension of the input file
extension = ext or os.path.splitext(filename)[1]

# Generate the output file path
output_file = os.path.join(directory, filename_without_extension + tag + extension)

return output_file

try:
logger.info("Starting make_output_file_path")
# Define the directory to save enhanced audio files.
directory = "./enhanced"
# Extract the filename from the original audio file path.
filename = os.path.basename(audio_file)
# Separate the filename from its extension.
filename_without_extension = os.path.splitext(filename)[0]
# Use the provided extension or fall back to the original extension.
extension = ext or os.path.splitext(filename)[1]
# Construct the output file path with the added tag and extension.
output_file = os.path.join(directory, filename_without_extension + tag + extension)
logger.info("Finished make_output_file_path")
return output_file
except Exception as e:
logger.error(f"Error in make_output_file_path: {e}")

class BaseEnhancer(ABC):
"""Base class for audio enhancers"""

"""
Abstract base class for audio enhancers. Implementations should override the __call__ method.
"""
def __init__(self, *args, **kwargs):
raise NotImplementedError
try:
logger.info("Initializing BaseEnhancer")
# Abstract classes cannot be instantiated.
raise NotImplementedError
except Exception as e:
logger.error(f"Error in BaseEnhancer.__init__: {e}")

def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
raise NotImplementedError
"""
Enhances an audio file. This method must be implemented by subclasses.

def get_output_file(self, audio_file: str, tag: str, ext: Optional[str] = None) -> str:
output_file = make_output_file_path(audio_file, tag, ext=ext)
os.makedirs(os.path.dirname(output_file), exist_ok=True)
return output_file
Args:
audio_file (str): Path to the input audio file.
output_file (Optional[str]): Optional path to save the enhanced audio file.

Raises:
NotImplementedError: If the subclass does not implement this method.
"""
try:
raise NotImplementedError
except Exception as e:
logger.error(f"Error in BaseEnhancer.__call__: {e}")

def get_output_file(self, audio_file: str, tag: str, ext: Optional[str] = None) -> str:
"""
Generates a path for the output file using the specified tag and extension.

Args:
audio_file (str): Path to the original audio file.
tag (str): Tag to append to the filename.
ext (Optional[str]): Optional custom extension for the output file.

Returns:
str: Path for the output file with the tag and extension.
"""
try:
logger.info("Starting BaseEnhancer.get_output_file")
# Generate the output file path with the specified tag and extension.
output_file = make_output_file_path(audio_file, tag, ext=ext)
# Ensure the directory for the output file exists.
os.makedirs(os.path.dirname(output_file), exist_ok=True)
logger.info("Finished BaseEnhancer.get_output_file")
return output_file
except Exception as e:
logger.error(f"Error in BaseEnhancer.get_output_file: {e}")

class DFEnhancer(BaseEnhancer):
"""
Enhancer class using the "df" enhancement algorithm. Inherits from BaseEnhancer.
"""
def __init__(self, *args, **kwargs):
self.model, self.df_state, _ = init_df()
try:
logger.info("Starting DFEnhancer.__init__")
# Initialize the enhancement model and state.
self.model, self.df_state, _ = init_df()
logger.info("Finished DFEnhancer.__init__")
except Exception as e:
logger.error(f"Error in DFEnhancer.__init__: {e}")

def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
output_file = output_file or self.get_output_file(audio_file, "_df")

audio, _ = load_audio(audio_file, sr=self.df_state.sr())

enhanced = enhance(self.model, self.df_state, audio)

save_audio(output_file, enhanced, self.df_state.sr())

return output_file

"""
Enhances an audio file using the "df" enhancement algorithm.

Args:
audio_file (str): Path to the input audio file.
output_file (Optional[str]): Optional path to save the enhanced audio file.

Returns:
str: Path to the enhanced audio file.
"""
try:
logger.info("Starting DFEnhancer.__call__")
# Determine the output file path if not provided.
output_file = output_file or self.get_output_file(audio_file, "_df")
# Load the audio file and enhance it using the "df" algorithm.
audio, _ = load_audio(audio_file, sr=self.df_state.sr())
enhanced = enhance(self.model, self.df_state, audio)
# Save the enhanced audio to the specified output file.
save_audio(output_file, enhanced, self.df_state.sr())
logger.info("Finished DFEnhancer.__call__")
return output_file
except Exception as e:
logger.error(f"Error in DFEnhancer.__call__: {e}")

def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer:
"""Get an audio enhancer
"""
Factory function to get an enhancer instance based on the enhancer name.

Args:
enhancer_name (Literal["df"]): name of the audio enhancer

Raises:
ValueError: if the enhancer name is not recognised
enhancer_name (Literal["df"]): Name of the enhancer to instantiate.

Returns:
BaseEnhancer: audio enhancer
"""
BaseEnhancer: Instance of the specified enhancer.

if enhancer_name == "df":
return DFEnhancer()
else:
raise ValueError(f"Unknown enhancer name: {enhancer_name}")
Raises:
ValueError: If an unknown enhancer name is provided.
"""
try:
logger.info("Starting get_enhancer")
# Instantiate the appropriate enhancer based on the provided name.
if enhancer_name == "df":
enhancer = DFEnhancer()
else:
# Raise an error for unsupported enhancer names.
raise ValueError(f"Unknown enhancer name: {enhancer_name}")
logger.info("Finished get_enhancer")
return enhancer
except Exception as e:
logger.error(f"Error in get_enhancer: {e}")
31 changes: 24 additions & 7 deletions fam/llm/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
from fam.quantiser.text.tokenise import TrainedBPETokeniser


import time
import logging
import traceback
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@dataclass
class InferenceConfig:
ckpt_path: str # path to checkpoint
Expand Down Expand Up @@ -458,21 +463,25 @@ def _sample_utterance_batch(
temperature: Optional[float],
batch_size: int = 128,
) -> List[str]:
logger.info("Starting _sample_utterance_batch")

speaker_embs = []
refs = spk_cond_paths.copy()

# multithreaded loop to cache all the files
# Multithreaded loop to cache all the files
logger.info("Caching speaker reference files")
spk_cond_paths = tqdm.contrib.concurrent.thread_map(
get_cached_file, spk_cond_paths, desc="getting cached speaker ref files"
)

logger.info("Calculating speaker embeddings")
for i, (text, spk_cond_path) in tqdm.tqdm(
enumerate(zip(texts, spk_cond_paths)), total=len(texts), desc="calculating speaker embeddings"
):
texts[i] = normalize_text(text)
speaker_embs.append(get_cached_embedding(spk_cond_path, spkemb_model) if spk_cond_path else None)

logger.info("Processing with first stage model")
b_speaker_embs = torch.cat(speaker_embs, dim=0)
b_tokens = first_stage_model(
texts=texts,
Expand All @@ -485,7 +494,7 @@ def _sample_utterance_batch(
max_new_tokens=max_new_tokens,
)

# TODO: set batch size for second stage model!
logger.info("Processing with second stage model")
wav_files = second_stage_model(
texts=texts,
encodec_tokens=b_tokens,
Expand All @@ -498,25 +507,33 @@ def _sample_utterance_batch(
max_new_tokens=None,
)

logger.info("Post-processing generated WAV files")
for text, tokens, speaker_embs, ref_name, wav_file in zip(texts, b_tokens, b_speaker_embs, refs, wav_files):
if wav_file is None:
continue

with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as enhanced_tmp:
if enhancer is not None:
logger.info(f"Enhancing audio for {wav_file}")
enhancer = get_enhancer(enhancer) if isinstance(enhancer, str) else enhancer
enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
# copy enhanced_tmp.name back to wav_file
print(f"copying enhanced file from {enhanced_tmp.name} to {str(wav_file) + '.wav'}.")
shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")

logger.info(f"Copying enhanced file from {enhanced_tmp.name} to {str(wav_file) + '.wav'}")
src = os.path.abspath(enhanced_tmp.name)
dst = os.path.abspath(str(wav_file) + ".wav")
time.sleep(1) # Sleep for 1 second to ensure file is released
shutil.copy2(src, dst)

logger.info(f"Saving result metadata for {wav_file}")
save_result_metadata(
wav_file,
ref_name,
text,
first_stage_ckpt_path,
second_stage_ckpt_path,
)

logger.info("Finished _sample_utterance_batch")
return [str(w) + ".wav" if not str(w).endswith(".wav") else str(w) for w in wav_files]


Expand Down
Loading