4 changes: 4 additions & 0 deletions .gitignore
@@ -137,3 +137,7 @@ dmypy.json

# Pyre type checker
.pyre/
.ljspeech

# custom
outputs
6 changes: 3 additions & 3 deletions config/training_config.yaml
@@ -1,7 +1,7 @@
paths:
# PATHS: change accordingly
wav_directory: '/path/to/wav_directory' # path to directory containing the wavs
metadata_path: '/path/to/metadata.csv' # name of metadata file under wav_directory
log_directory: '/path/to/logs_directory' # weights and logs are stored here
train_data_directory: 'transformer_tts_data' # training data is stored here

@@ -51,7 +51,7 @@ audio_settings:

text_settings:
# TOKENIZER
-phoneme_language: 'en-us'
+phoneme_language: 'en-us' # use 'de' for German; use false if the input data is already phonemized
with_stress: True # use stress symbols in phonemization
model_breathing: false # add a token for the initial breathing

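With phoneme_language set to false, the pipeline treats the metadata text as phonemes verbatim. A hypothetical metadata line for that case (pipe-separated file id and IPA text, mirroring the `{file_id}|{text}` format that create_training_data.py itself writes; the id and transcription are made up):

LJ001-0001|ˈprɪntɪŋ ɪn ðə ˈoʊnli sɛns wɪθ wɪtʃ wiː ɑːr æt ˈprɛzənt kənˈsɜːnd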
154 changes: 81 additions & 73 deletions create_training_data.py
@@ -10,13 +10,12 @@
from data.datasets import DataReader
from utils.training_config_manager import TrainingConfigManager
from data.audio import Audio
-from data.text.symbols import _alphabet
+from data.text.symbols import _alphabet, all_phonemes

np.random.seed(42)

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True)
parser.add_argument('--skip_phonemes', action='store_true')
parser.add_argument('--skip_mels', action='store_true')

args = parser.parse_args()
@@ -97,84 +96,93 @@ def process_pitches(item: tuple):
total_wav_len = total_mel_len * audio.config['hop_length']
summary_manager.display_scalar('Total duration (hours)',
scalar_value=total_wav_len / audio.config['sampling_rate'] / 60. ** 2)


if not args.skip_phonemes:
remove_files = pickle.load(open(cm.data_dir / 'under-over_sized_mels.pkl', 'rb'))
phonemized_metadata_path = cm.phonemized_metadata_path
train_metadata_path = cm.train_metadata_path
test_metadata_path = cm.valid_metadata_path
print(f'\nReading metadata from {metadatareader.metadata_path}')
print(f'\nFound {len(metadatareader.filenames)} lines.')
def get_short_files(phonemized=False):
Contributor: type hints missing.
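One way to add them (a sketch; assuming the file ids are strings, since the function returns filter_metadata, a list of ids):

from typing import List

def get_short_files(phonemized: bool = False) -> List[str]:
    ...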

if not phonemized:
Contributor suggested change:
-    if not phonemized:
+    symbol_list = all_phonemes if phonemized else _alphabet

symbol_list = _alphabet
else:
symbol_list = all_phonemes

filter_metadata = []
for fname in cross_file_ids:
item = metadatareader.text_dict[fname]
-non_p = [c for c in item if c in _alphabet]
+non_p = [c for c in item if c in symbol_list]
if len(non_p) < 1:
filter_metadata.append(fname)
if len(filter_metadata) > 0:
print(f'Removing {len(filter_metadata)} suspiciously short line(s):')
for fname in filter_metadata:
print(f'{fname}: {metadatareader.text_dict[fname]}')
print(f'\nRemoving {len(remove_files)} line(s) due to mel filtering.')
remove_files += filter_metadata
metadata_file_ids = [fname for fname in cross_file_ids if fname not in remove_files]
metadata_len = len(metadata_file_ids)
sample_items = np.random.choice(metadata_file_ids, 5)
test_len = cm.config['n_test']
train_len = metadata_len - test_len
print(f'\nMetadata contains {metadata_len} lines.')
print(f'\nFiles will be stored under {cm.data_dir}')
print(f' - all: {phonemized_metadata_path}')
print(f' - {train_len} training lines: {train_metadata_path}')
print(f' - {test_len} validation lines: {test_metadata_path}')

print('\nMetadata samples:')
for i in sample_items:
print(f'{i}:{metadatareader.text_dict[i]}')
summary_manager.add_text(f'{i}/text', text=metadatareader.text_dict[i])

# run cleaner on raw text
text_proc = TextToTokens.default(cm.config['phoneme_language'], add_start_end=False,
with_stress=cm.config['with_stress'], model_breathing=cm.config['model_breathing'],
njobs=1)


def process_phonemes(file_id):
text = metadatareader.text_dict[file_id]
try:
phon = text_proc.phonemizer(text)
except Exception as e:
print(f'{e}\nFile id {file_id}')
raise BrokenPipeError
return (file_id, phon)


print('\nPHONEMIZING')
phonemized_data = {}
phon_iter = p_uimap(process_phonemes, metadata_file_ids)
for (file_id, phonemes) in phon_iter:
phonemized_data.update({file_id: phonemes})

print('\nPhonemized metadata samples:')
for i in sample_items:
print(f'{i}:{phonemized_data[i]}')
summary_manager.add_text(f'{i}/phonemes', text=phonemized_data[i])

new_metadata = [f'{k}|{v}\n' for k, v in phonemized_data.items()]
shuffled_metadata = np.random.permutation(new_metadata)
train_metadata = shuffled_metadata[0:train_len]
test_metadata = shuffled_metadata[-test_len:]

with open(phonemized_metadata_path, 'w+', encoding='utf-8') as file:
file.writelines(new_metadata)
with open(train_metadata_path, 'w+', encoding='utf-8') as file:
file.writelines(train_metadata)
with open(test_metadata_path, 'w+', encoding='utf-8') as file:
file.writelines(test_metadata)
# some checks
assert metadata_len == len(set(list(phonemized_data.keys()))), \
f'Length of metadata ({metadata_len}) does not match the length of the phoneme array ({len(set(list(phonemized_data.keys())))}). Check for empty text lines in metadata.'
assert len(train_metadata) + len(test_metadata) == metadata_len, \
f'Train and/or validation lengths incorrect. ({len(train_metadata)} + {len(test_metadata)} != {metadata_len})'

print('\nDone')
return filter_metadata


remove_files = pickle.load(open(cm.data_dir / 'under-over_sized_mels.pkl', 'rb'))
Contributor: unclosed file I/O.
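A minimal fix for that (same path and pickle call as the line above; the with-block guarantees the handle is closed even if loading raises):

with open(cm.data_dir / 'under-over_sized_mels.pkl', 'rb') as f:
    remove_files = pickle.load(f)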

phonemized_metadata_path = cm.phonemized_metadata_path
train_metadata_path = cm.train_metadata_path
test_metadata_path = cm.valid_metadata_path
Contributor: inconsistent naming of train and validation.

print(f'\nReading metadata from {metadatareader.metadata_path}')
print(f'\nFound {len(metadatareader.filenames)} lines.')

filter_metadata = get_short_files(phonemized=not cm.config['phoneme_language'])
remove_files += filter_metadata
print(f'\nRemoving {len(remove_files)} line(s) due to mel filtering.')
metadata_file_ids = [fname for fname in cross_file_ids if fname not in remove_files]
metadata_len = len(metadata_file_ids)
sample_items = np.random.choice(metadata_file_ids, 5)
test_len = cm.config['n_test']
train_len = metadata_len - test_len
print(f'\nMetadata contains {metadata_len} lines.')
print(f'\nFiles will be stored under {cm.data_dir}')
print(f' - all: {phonemized_metadata_path}')
print(f' - {train_len} training lines: {train_metadata_path}')
print(f' - {test_len} validation lines: {test_metadata_path}')

# run cleaner on raw text
text_proc = TextToTokens.default(cm.config['phoneme_language'],
add_start_end=False,
with_stress=cm.config['with_stress'],
model_breathing=cm.config['model_breathing'],
njobs=1)


def process_phonemes(file_id):
text = metadatareader.text_dict[file_id]
try:
phon = text_proc.phonemizer(text)
except Exception as e:
print(f'{e}\nFile id {file_id}')
raise BrokenPipeError
return (file_id, phon)


print('\nPHONEMIZING')
phonemized_data = {}
phon_iter = p_uimap(process_phonemes, metadata_file_ids)
for (file_id, phonemes) in phon_iter:
phonemized_data.update({file_id: phonemes})

print('\nPhonemized metadata samples:')
for i in sample_items:
print(f'{i}:{phonemized_data[i]}')
summary_manager.add_text(f'{i}/phonemes', text=phonemized_data[i])

new_metadata = [f'{k}|{v}\n' for k, v in phonemized_data.items()]
shuffled_metadata = np.random.permutation(new_metadata)
train_metadata = shuffled_metadata[0:train_len]
Contributor suggested change:
-train_metadata = shuffled_metadata[0:train_len]
+train_metadata = shuffled_metadata[0:-test_len]

Contributor: I think it's safer to just use test_len; it also saves a couple of lines.
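A standalone sketch of that point: with a single test_len the two slices partition the shuffled data exactly (assuming test_len > 0, since data[:-0] would be empty):

import numpy as np

data = np.random.permutation([f'line_{i}' for i in range(10)])
test_len = 2
train, test = data[:-test_len], data[-test_len:]
assert len(train) + len(test) == len(data)  # no overlap, nothing dropped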

test_metadata = shuffled_metadata[-test_len:]

with open(phonemized_metadata_path, 'w+', encoding='utf-8') as file:
file.writelines(new_metadata)
with open(train_metadata_path, 'w+', encoding='utf-8') as file:
file.writelines(train_metadata)
with open(test_metadata_path, 'w+', encoding='utf-8') as file:
file.writelines(test_metadata)

# some checks
assert metadata_len == len(set(list(phonemized_data.keys()))), \
Contributor suggested change:
-assert metadata_len == len(set(list(phonemized_data.keys()))), \
+assert metadata_len == len(phonemized_data)

Contributor: Same for the other dict keys.

f'Length of metadata ({metadata_len}) does not match the length of the phoneme array ({len(set(list(phonemized_data.keys())))}). Check for empty text lines in metadata.'
assert len(train_metadata) + len(test_metadata) == metadata_len, \
f'Train and/or validation lengths incorrect. ({len(train_metadata)} + {len(test_metadata)} != {metadata_len})'

print('\nDone')
3 changes: 2 additions & 1 deletion data/text/symbols.py
@@ -4,8 +4,9 @@
_suprasegmentals = 'ˈˌːˑ'
_other_symbols = 'ʍwɥʜʢʡɕʑɺɧ'
_diacrilics = 'ɚ˞ɫ'
_extra_phons = ['g', 'ɝ', '̃', '̍', '̥', '̩', '̯', '͡'] # some extra symbols from Wiktionary IPA annotations
_phonemes = sorted(list(
-    _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics))
+    _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics)) + _extra_phons
_punctuations = '!,-.:;? \'()'
_alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzäüößÄÖÜ'

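A side note on sorted(...) + _extra_phons: appending after the sort keeps the new symbols at the end of the list, so the positions of the pre-existing phonemes stay stable, which matters if token indices are derived from list position. A toy illustration (not the repo's symbols):

base = sorted(list('cab'))     # ['a', 'b', 'c']
extended = base + ['g', 'ɝ']   # new entries appended at the end
assert extended[:len(base)] == base  # existing positions unchanged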
25 changes: 14 additions & 11 deletions data/text/tokenizer.py
@@ -33,7 +33,7 @@ def __init__(self, start_token='>', end_token='<', pad_token='/', add_start_end=
self.breathing_token = '@'
self.idx_to_token[self.breathing_token_index] = self.breathing_token
self.token_to_idx[self.breathing_token] = [self.breathing_token_index]

def __call__(self, sentence: str) -> list:
sequence = [self.token_to_idx[c] for c in sentence] # No filtering: text should only contain known chars.
sequence = [item for items in sequence for item in items]
@@ -62,16 +62,19 @@ def __call__(self, text: Union[str, list], with_stress=None, njobs=None, languag
         njobs = njobs or self.njobs
         with_stress = with_stress or self.with_stress
         # phonemizer does not like hyphens.
-        text = self._preprocess(text)
-        phonemes = phonemize(text,
-                             language=language,
-                             backend='espeak',
-                             strip=True,
-                             preserve_punctuation=True,
-                             with_stress=with_stress,
-                             punctuation_marks=self.punctuation,
-                             njobs=njobs,
-                             language_switch='remove-flags')
+        if language:
+            text = self._preprocess(text)
+            phonemes = phonemize(text,
+                                 language=language,
+                                 backend='espeak',
+                                 strip=True,
+                                 preserve_punctuation=True,
+                                 with_stress=with_stress,
+                                 punctuation_marks=self.punctuation,
+                                 njobs=njobs,
+                                 language_switch='remove-flags')
+        else:
+            phonemes = text
         return self._postprocess(phonemes)

def _preprocess_string(self, text: str):
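With this change, a falsy language skips espeak entirely and the input is only post-processed. A hedged usage sketch mirroring how create_training_data.py builds the pipeline (False stands in for phoneme_language: false from the config; the pass-through assumes the stored language is used when the call omits one):

text_proc = TextToTokens.default(False, add_start_end=False,
                                 with_stress=True, model_breathing=False, njobs=1)
phon = text_proc.phonemizer('ˈhɛloʊ wɜːld')  # no espeak call; text goes straight to _postprocess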
2 changes: 1 addition & 1 deletion predict_tts.py
@@ -42,7 +42,7 @@
outdir = outdir / 'outputs' / f'{fname}'
outdir.mkdir(exist_ok=True, parents=True)
output_path = (outdir / file_name).with_suffix('.wav')
audio = Audio.from_config(model.config)
print(f'Output wav under {output_path.parent}')
wavs = []
for i, text_line in enumerate(text):
12 changes: 6 additions & 6 deletions tests/test_char_tokenizer.py
@@ -12,12 +12,12 @@ def test_tokenizer(self):
         tokenizer = Tokenizer(alphabet=list('ab c'))
         self.assertEqual(5, tokenizer.start_token_index)
         self.assertEqual(6, tokenizer.end_token_index)
-        self.assertEqual(7, tokenizer.vocab_size)
-        seq = tokenizer('a b d')
-        self.assertEqual([5, 1, 3, 2, 3, 6], seq)
-        seq = np.array([5, 1, 3, 2, 8, 6])
+        self.assertEqual(8, tokenizer.vocab_size)
+
+        seq = tokenizer('a b c')
+        self.assertEqual([5, 7, 2, 1, 7, 3, 1, 7, 4, 6], seq)
+
+        seq = np.array([5, 2, 1, 3, 6])
         seq = tf.convert_to_tensor(seq)
         text = tokenizer.decode(seq)
         self.assertEqual('>a b<', text)
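For decoding the new expected sequence by eye: with the sorted alphabet (' '=1, 'a'=2, 'b'=3, 'c'=4), start '>'=5, end '<'=6, and the breathing token '@'=7, a sketch inferred purely from the test's own expectations:

idx_to_token = {1: ' ', 2: 'a', 3: 'b', 4: 'c', 5: '>', 6: '<', 7: '@'}  # inferred; pad '/'=0 omitted
decoded = ''.join(idx_to_token[i] for i in [5, 7, 2, 1, 7, 3, 1, 7, 4, 6])
print(decoded)  # '>@a @b @c<' — a breathing token after the start token and after each space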