diff --git a/.gitignore b/.gitignore
index 6ad59d2..68a7689 100644
--- a/.gitignore
+++ b/.gitignore
@@ -137,3 +137,7 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+.ljspeech
+
+# custom
+outputs
\ No newline at end of file
diff --git a/config/training_config.yaml b/config/training_config.yaml
index 2a12af0..0368c38 100644
--- a/config/training_config.yaml
+++ b/config/training_config.yaml
@@ -1,7 +1,7 @@
 paths:
   # PATHS: change accordingly
-  wav_directory: '/path/to/wav_directory'  # path to directory cointaining the wavs
-  metadata_path: '/path/to/metadata.csv'  # name of metadata file under wav_directory
+  wav_directory: '/path/to/wav_directory' # path to directory containing the wavs
+  metadata_path: '/path/to/metadata.csv' # name of metadata file under wav_directory
   log_directory: '/path/to/logs_directory' # weights and logs are stored here
   train_data_directory: 'transformer_tts_data' # training data is stored here
 
@@ -51,7 +51,7 @@ audio_settings:
 
 text_settings:
   # TOKENIZER
-  phoneme_language: 'en-us'
+  phoneme_language: 'en-us' # use 'de' for German; set to false if the input data is already phonemized
   with_stress: True # use stress symbols in phonemization
   model_breathing: false # add a token for the initial breathing
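
The new phoneme_language comment makes that key do double duty: a language code
such as 'en-us' or 'de' selects the espeak voice, while a falsy value tells the
pipeline that the metadata is already phonemized. A minimal sketch of why that
works, assuming the config is loaded with PyYAML (or anything with the same
truthiness semantics):

    import yaml  # PyYAML

    # YAML 'false' loads as Python False, so downstream code can simply
    # branch on truthiness: falsy means "skip phonemization, pass through".
    cfg = yaml.safe_load("phoneme_language: false")
    assert cfg['phoneme_language'] is False
    assert not cfg['phoneme_language']
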
diff --git a/create_training_data.py b/create_training_data.py
index bacd22e..00cf1b6 100644
--- a/create_training_data.py
+++ b/create_training_data.py
@@ -10,13 +10,12 @@ from data.datasets import DataReader
 from utils.training_config_manager import TrainingConfigManager
 from data.audio import Audio
-from data.text.symbols import _alphabet
+from data.text.symbols import _alphabet, all_phonemes
 
 np.random.seed(42)
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--config', type=str, required=True)
-parser.add_argument('--skip_phonemes', action='store_true')
 parser.add_argument('--skip_mels', action='store_true')
 args = parser.parse_args()
@@ -97,84 +96,93 @@ def process_pitches(item: tuple):
 total_wav_len = total_mel_len * audio.config['hop_length']
 summary_manager.display_scalar('Total duration (hours)', scalar_value=total_wav_len / audio.config['sampling_rate'] / 60. ** 2)
+
-if not args.skip_phonemes:
-    remove_files = pickle.load(open(cm.data_dir / 'under-over_sized_mels.pkl', 'rb'))
-    phonemized_metadata_path = cm.phonemized_metadata_path
-    train_metadata_path = cm.train_metadata_path
-    test_metadata_path = cm.valid_metadata_path
-    print(f'\nReading metadata from {metadatareader.metadata_path}')
-    print(f'\nFound {len(metadatareader.filenames)} lines.')
+def get_short_files(phonemized=False):
+    if not phonemized:
+        symbol_list = _alphabet
+    else:
+        symbol_list = all_phonemes
+    filter_metadata = []
     for fname in cross_file_ids:
         item = metadatareader.text_dict[fname]
-        non_p = [c for c in item if c in _alphabet]
+        non_p = [c for c in item if c in symbol_list]
         if len(non_p) < 1:
             filter_metadata.append(fname)
     if len(filter_metadata) > 0:
         print(f'Removing {len(filter_metadata)} suspiciously short line(s):')
         for fname in filter_metadata:
             print(f'{fname}: {metadatareader.text_dict[fname]}')
-    print(f'\nRemoving {len(remove_files)} line(s) due to mel filtering.')
-    remove_files += filter_metadata
-    metadata_file_ids = [fname for fname in cross_file_ids if fname not in remove_files]
-    metadata_len = len(metadata_file_ids)
-    sample_items = np.random.choice(metadata_file_ids, 5)
-    test_len = cm.config['n_test']
-    train_len = metadata_len - test_len
-    print(f'\nMetadata contains {metadata_len} lines.')
-    print(f'\nFiles will be stored under {cm.data_dir}')
-    print(f' - all: {phonemized_metadata_path}')
-    print(f' - {train_len} training lines: {train_metadata_path}')
-    print(f' - {test_len} validation lines: {test_metadata_path}')
-
-    print('\nMetadata samples:')
-    for i in sample_items:
-        print(f'{i}:{metadatareader.text_dict[i]}')
-        summary_manager.add_text(f'{i}/text', text=metadatareader.text_dict[i])
-
-    # run cleaner on raw text
-    text_proc = TextToTokens.default(cm.config['phoneme_language'], add_start_end=False,
-                                     with_stress=cm.config['with_stress'], model_breathing=cm.config['model_breathing'],
-                                     njobs=1)
-
-
-    def process_phonemes(file_id):
-        text = metadatareader.text_dict[file_id]
-        try:
-            phon = text_proc.phonemizer(text)
-        except Exception as e:
-            print(f'{e}\nFile id {file_id}')
-            raise BrokenPipeError
-        return (file_id, phon)
-
-
-    print('\nPHONEMIZING')
-    phonemized_data = {}
-    phon_iter = p_uimap(process_phonemes, metadata_file_ids)
-    for (file_id, phonemes) in phon_iter:
-        phonemized_data.update({file_id: phonemes})
-
-    print('\nPhonemized metadata samples:')
-    for i in sample_items:
-        print(f'{i}:{phonemized_data[i]}')
-        summary_manager.add_text(f'{i}/phonemes', text=phonemized_data[i])
-
-    new_metadata = [f'{k}|{v}\n' for k, v in phonemized_data.items()]
-    shuffled_metadata = np.random.permutation(new_metadata)
-    train_metadata = shuffled_metadata[0:train_len]
-    test_metadata = shuffled_metadata[-test_len:]
-
-    with open(phonemized_metadata_path, 'w+', encoding='utf-8') as file:
-        file.writelines(new_metadata)
-    with open(train_metadata_path, 'w+', encoding='utf-8') as file:
-        file.writelines(train_metadata)
-    with open(test_metadata_path, 'w+', encoding='utf-8') as file:
-        file.writelines(test_metadata)
-    # some checks
-    assert metadata_len == len(set(list(phonemized_data.keys()))), \
-        f'Length of metadata ({metadata_len}) does not match the length of the phoneme array ({len(set(list(phonemized_data.keys())))}). Check for empty text lines in metadata.'
-    assert len(train_metadata) + len(test_metadata) == metadata_len, \
-        f'Train and/or validation lengths incorrect. ({len(train_metadata)} + {len(test_metadata)} != {metadata_len})'
-
-print('\nDone')
+    return filter_metadata
+
+
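+# Everything below now runs unconditionally: the old --skip_phonemes flag is
+# gone. When phoneme_language is falsy, the phonemizer acts as a passthrough
+# (see data/text/tokenizer.py), so pre-phonemized metadata takes the same path.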
+remove_files = pickle.load(open(cm.data_dir / 'under-over_sized_mels.pkl', 'rb'))
+phonemized_metadata_path = cm.phonemized_metadata_path
+train_metadata_path = cm.train_metadata_path
+test_metadata_path = cm.valid_metadata_path
+print(f'\nReading metadata from {metadatareader.metadata_path}')
+print(f'\nFound {len(metadatareader.filenames)} lines.')
+
+filter_metadata = get_short_files(phonemized=not cm.config['phoneme_language'])
+remove_files += filter_metadata
+print(f'\nRemoving {len(remove_files)} line(s) due to mel and short-text filtering.')
+metadata_file_ids = [fname for fname in cross_file_ids if fname not in remove_files]
+metadata_len = len(metadata_file_ids)
+sample_items = np.random.choice(metadata_file_ids, 5)
+test_len = cm.config['n_test']
+train_len = metadata_len - test_len
+print(f'\nMetadata contains {metadata_len} lines.')
+print(f'\nFiles will be stored under {cm.data_dir}')
+print(f' - all: {phonemized_metadata_path}')
+print(f' - {train_len} training lines: {train_metadata_path}')
+print(f' - {test_len} validation lines: {test_metadata_path}')
+
+# run cleaner on raw text
+text_proc = TextToTokens.default(cm.config['phoneme_language'],
+                                 add_start_end=False,
+                                 with_stress=cm.config['with_stress'],
+                                 model_breathing=cm.config['model_breathing'],
+                                 njobs=1)
+
+
+def process_phonemes(file_id):
+    text = metadatareader.text_dict[file_id]
+    try:
+        phon = text_proc.phonemizer(text)
+    except Exception as e:
+        print(f'{e}\nFile id {file_id}')
+        raise BrokenPipeError
+    return (file_id, phon)
+
+
+print('\nPHONEMIZING')
+phonemized_data = {}
+phon_iter = p_uimap(process_phonemes, metadata_file_ids)
+for (file_id, phonemes) in phon_iter:
+    phonemized_data.update({file_id: phonemes})
+
+print('\nPhonemized metadata samples:')
+for i in sample_items:
+    print(f'{i}:{phonemized_data[i]}')
+    summary_manager.add_text(f'{i}/phonemes', text=phonemized_data[i])
+
+new_metadata = [f'{k}|{v}\n' for k, v in phonemized_data.items()]
+shuffled_metadata = np.random.permutation(new_metadata)
+train_metadata = shuffled_metadata[0:train_len]
+test_metadata = shuffled_metadata[-test_len:]
+
+with open(phonemized_metadata_path, 'w+', encoding='utf-8') as file:
+    file.writelines(new_metadata)
+with open(train_metadata_path, 'w+', encoding='utf-8') as file:
+    file.writelines(train_metadata)
+with open(test_metadata_path, 'w+', encoding='utf-8') as file:
+    file.writelines(test_metadata)
+
+# some checks
+assert metadata_len == len(set(list(phonemized_data.keys()))), \
+    f'Length of metadata ({metadata_len}) does not match the length of the phoneme array ({len(set(list(phonemized_data.keys())))}). Check for empty text lines in metadata.'
+assert len(train_metadata) + len(test_metadata) == metadata_len, \
+    f'Train and/or validation lengths incorrect. ({len(train_metadata)} + {len(test_metadata)} != {metadata_len})'
+
+print('\nDone')
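
The core of the new get_short_files helper is a one-line symbol filter: a line
is dropped when fewer than one of its characters is a known symbol. A standalone
sketch of that check; the helper name and toy alphabet here are illustrative
only, not part of the patch:

    _toy_alphabet = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'

    def is_suspiciously_short(text: str, symbol_list: str = _toy_alphabet) -> bool:
        # Keep only characters the tokenizer knows; zero survivors means the
        # line is punctuation/whitespace only and is useless for training.
        return len([c for c in text if c in symbol_list]) < 1

    assert is_suspiciously_short('...')         # filtered out
    assert not is_suspiciously_short('Hello.')  # kept
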
diff --git a/data/text/symbols.py b/data/text/symbols.py
index 809e4ec..9eb233d 100644
--- a/data/text/symbols.py
+++ b/data/text/symbols.py
@@ -4,8 +4,9 @@
 _suprasegmentals = 'ˈˌːˑ'
 _other_symbols = 'ʍwɥʜʢʡɕʑɺɧ'
 _diacrilics = 'ɚ˞ɫ'
+_extra_phons = ['g', 'ɝ', '̃', '̍', '̥', '̩', '̯', '͡']  # some extra symbols from Wiktionary IPA annotations
 _phonemes = sorted(list(
-    _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics))
+    _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics)) + _extra_phons
 _punctuations = '!,-.:;? \'()'
 _alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzäüößÄÖÜ'
diff --git a/data/text/tokenizer.py b/data/text/tokenizer.py
index 98a176c..5b98892 100644
--- a/data/text/tokenizer.py
+++ b/data/text/tokenizer.py
@@ -33,7 +33,7 @@ def __init__(self, start_token='>', end_token='<', pad_token='/', add_start_end=
         self.breathing_token = '@'
         self.idx_to_token[self.breathing_token_index] = self.breathing_token
         self.token_to_idx[self.breathing_token] = [self.breathing_token_index]
-    
+
     def __call__(self, sentence: str) -> list:
         sequence = [self.token_to_idx[c] for c in sentence]  # No filtering: text should only contain known chars.
         sequence = [item for items in sequence for item in items]
@@ -62,16 +62,19 @@ def __call__(self, text: Union[str, list], with_stress=None, njobs=None, languag
         njobs = njobs or self.njobs
         with_stress = with_stress or self.with_stress
         # phonemizer does not like hyphens.
-        text = self._preprocess(text)
-        phonemes = phonemize(text,
-                             language=language,
-                             backend='espeak',
-                             strip=True,
-                             preserve_punctuation=True,
-                             with_stress=with_stress,
-                             punctuation_marks=self.punctuation,
-                             njobs=njobs,
-                             language_switch='remove-flags')
+        if language:
+            text = self._preprocess(text)
+            phonemes = phonemize(text,
+                                 language=language,
+                                 backend='espeak',
+                                 strip=True,
+                                 preserve_punctuation=True,
+                                 with_stress=with_stress,
+                                 punctuation_marks=self.punctuation,
+                                 njobs=njobs,
+                                 language_switch='remove-flags')
+        else:
+            phonemes = text
         return self._postprocess(phonemes)
 
     def _preprocess_string(self, text: str):
diff --git a/predict_tts.py b/predict_tts.py
index 5063e7a..feb4ebd 100644
--- a/predict_tts.py
+++ b/predict_tts.py
@@ -42,7 +42,7 @@
     outdir = outdir / 'outputs' / f'{fname}'
     outdir.mkdir(exist_ok=True, parents=True)
     output_path = (outdir / file_name).with_suffix('.wav')
-    audio = Audio.from_config(model.config) 
+    audio = Audio.from_config(model.config)
     print(f'Output wav under {output_path.parent}')
     wavs = []
     for i, text_line in enumerate(text):
diff --git a/tests/test_char_tokenizer.py b/tests/test_char_tokenizer.py
index a9ab2fa..e52f06d 100644
--- a/tests/test_char_tokenizer.py
+++ b/tests/test_char_tokenizer.py
@@ -12,12 +12,12 @@ def test_tokenizer(self):
         tokenizer = Tokenizer(alphabet=list('ab c'))
         self.assertEqual(5, tokenizer.start_token_index)
         self.assertEqual(6, tokenizer.end_token_index)
-        self.assertEqual(7, tokenizer.vocab_size)
-
-        seq = tokenizer('a b d')
-        self.assertEqual([5, 1, 3, 2, 3, 6], seq)
-
-        seq = np.array([5, 1, 3, 2, 8, 6])
+        self.assertEqual(8, tokenizer.vocab_size)
+
+        seq = tokenizer('a b c')
+        self.assertEqual([5, 7, 2, 1, 7, 3, 1, 7, 4, 6], seq)
+
+        seq = np.array([5, 2, 1, 3, 6])
         seq = tf.convert_to_tensor(seq)
         text = tokenizer.decode(seq)
         self.assertEqual('>a b<', text)
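
Taken together, the tokenizer change reduces to a simple contract: phonemize when
a language is set, hand the text back untouched otherwise. A runnable sketch with
a stub in place of the real espeak-backed phonemize() call (the names below are
illustrative, not the repo's API):

    def fake_phonemize(text: str, language: str) -> str:
        # Stand-in for phonemizer.phonemize(...) so the sketch runs anywhere.
        return f'<{language} phonemes for: {text}>'

    def maybe_phonemize(text: str, language) -> str:
        if language:          # e.g. 'en-us' or 'de'
            return fake_phonemize(text, language)
        return text           # already phonemized: passthrough

    assert maybe_phonemize('həˈloʊ', False) == 'həˈloʊ'
    assert maybe_phonemize('hello', 'en-us').startswith('<en-us')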