|
| 1 | +#!/usr/bin/env python3.0 |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | + |
| 4 | +import os |
| 5 | +import shutil |
| 6 | +import multiprocessing |
| 7 | +from pathlib import Path |
| 8 | +import torchaudio |
| 9 | +import random |
| 10 | +import glob |
| 11 | +import logging |
| 12 | + |
| 13 | +import torch |
| 14 | + |
| 15 | +import satools.script_utils as script_utils |
| 16 | +from satools.infer_helper import load_model |
| 17 | +from satools.utils.kaldi import load_wav_from_scp |
| 18 | + |
| 19 | +def copy_data_dir(dataset_path, output_path): |
| 20 | + # Copy utt2spk wav.scp and so on, but not the directories inside (may contains clear or anonymzied *.wav) |
| 21 | + os.makedirs(output_path, exist_ok=True) |
| 22 | + for p in glob.glob(str(Path(dataset_path) / '*'), recursive=False): |
| 23 | + if os.path.isfile(p): |
| 24 | + shutil.copy(p, output_path) |
| 25 | + |
| 26 | +class Wav(): # for f0 extraction |
| 27 | + def __init__(self, w): |
| 28 | + self.wav = w |
| 29 | + |
| 30 | +class Dataset(torch.utils.data.Dataset): |
| 31 | + def __init__(self, id_wavs, get_f0_func): |
| 32 | + self.all_wavs = list(id_wavs.values()) |
| 33 | + self.all_keys = list(id_wavs.keys()) |
| 34 | + self.get_f0_func = get_f0_func |
| 35 | + |
| 36 | + def __len__(self): |
| 37 | + return len(self.all_wavs) |
| 38 | + |
| 39 | + def __getitem__(self, index): |
| 40 | + audio, freq = load_wav_from_scp(str(self.all_wavs[index])) |
| 41 | + f0 = self.get_f0_func(Wav(audio)) |
| 42 | + return {"utid": self.all_keys[index], |
| 43 | + "audio": audio, |
| 44 | + "f0": f0, |
| 45 | + "freq": freq} |
| 46 | + |
| 47 | +def collate_fn(item_list): |
| 48 | + batch_size = len(item_list) |
| 49 | + |
| 50 | + data_list_audio = [i['audio'] for i in item_list] |
| 51 | + lengths_tensor_audio = torch.tensor([i.shape[-1] for i in data_list_audio]) |
| 52 | + max_len_audio = torch.max(lengths_tensor_audio).item() |
| 53 | + output_audio = torch.zeros([batch_size, max_len_audio]) |
| 54 | + for i in range(batch_size): |
| 55 | + cur = data_list_audio[i] |
| 56 | + cur_len = data_list_audio[i].shape[-1] |
| 57 | + output_audio[i, :cur_len] = cur.squeeze() |
| 58 | + |
| 59 | + data_list_f0 = [i['f0'] for i in item_list] |
| 60 | + lengths_tensor_f0 = torch.tensor([i.shape[-1] for i in data_list_f0]) |
| 61 | + max_len_f0 = torch.max(lengths_tensor_f0).item() |
| 62 | + output_f0 = torch.zeros([batch_size, max_len_f0]) |
| 63 | + for i in range(batch_size): |
| 64 | + cur = data_list_f0[i] |
| 65 | + cur_len = data_list_f0[i].shape[-1] |
| 66 | + output_f0[i, :cur_len] = cur.squeeze() |
| 67 | + |
| 68 | + utids = [i['utid'] for i in item_list] |
| 69 | + freqs = [i['freq'] for i in item_list] |
| 70 | + return output_audio, output_f0, lengths_tensor_audio, utids, freqs |
| 71 | + |
| 72 | +def process_data(dataset_path: str, target_selection_algorithm: str, wavscp: dict, settings: dict, progress): |
| 73 | + results_dir = settings.results_dir |
| 74 | + dataset_path = Path(str(dataset_path)) |
| 75 | + output_path = Path(str(dataset_path) + settings.new_datadir_suffix) |
| 76 | + device = settings.device |
| 77 | + batch_size = settings.batch_size |
| 78 | + |
| 79 | + copy_data_dir(dataset_path, output_path) |
| 80 | + results_dir = output_path / results_dir |
| 81 | + os.makedirs(results_dir, exist_ok = True) |
| 82 | + |
| 83 | + wav_scp = dataset_path / 'wav.scp' |
| 84 | + utt2spk = dataset_path / 'utt2spk' |
| 85 | + wav_scp_out = output_path / 'wav.scp' |
| 86 | + |
| 87 | + model = load_model(settings.model) |
| 88 | + model.to(device) |
| 89 | + model.eval() |
| 90 | + possible_targets = model.spk.copy() # For spk and utt target_selection_algorithm random choice |
| 91 | + |
| 92 | + source_utt2spk = script_utils.read_wav_scp(utt2spk) |
| 93 | + out_spk2target = {} # For spk target_selection_algorithm |
| 94 | + |
| 95 | + |
| 96 | + @torch.no_grad() |
| 97 | + def process_wav(utid, freq, audio, f0, original_len): |
| 98 | + |
| 99 | + freq = freq[0] # assume all freq = in same batch (and so dataset) |
| 100 | + audio = audio.to(device) |
| 101 | + |
| 102 | + # Anonymize function |
| 103 | + model.set_f0(f0.to(device)) # CPU extracted by Dataloader (num_workers) |
| 104 | + # Batch select target spks from the available model list depending on target_selection_algorithm |
| 105 | + target_spks = [] |
| 106 | + if target_selection_algorithm == "constant": # The best way/most secure to evaluate privacy when applied to all dataset (train included) |
| 107 | + target_constant_spkid = settings.target_constant_spkid # For constant target_selection_algorithm |
| 108 | + target_spks = [target_constant_spkid]*audio.shape[0] |
| 109 | + elif target_selection_algorithm == "bad_for_evaluation": |
| 110 | + # This target selection algorithm is bad for evaluation as it does |
| 111 | + # not generate suitable training data for the ASV eval training |
| 112 | + # procedure. Use it with caution. |
| 113 | + for ut in utid: |
| 114 | + source_spk = source_utt2spk[ut] |
| 115 | + if source_spk not in out_spk2target: |
| 116 | + out_spk2target[source_spk] = random.sample(possible_targets, 2) |
| 117 | + target_spks.append(random.choice(out_spk2target[source_spk])) |
| 118 | + elif target_selection_algorithm == "random_per_utt": |
| 119 | + target_spks = [] |
| 120 | + for ut in utid: |
| 121 | + target_spks.append(random.choice(possible_targets)) |
| 122 | + elif target_selection_algorithm == "random_per_spk_uniq": |
| 123 | + for ut in utid: |
| 124 | + source_spk = source_utt2spk[ut] |
| 125 | + if source_spk not in out_spk2target: |
| 126 | + out_spk2target[source_spk] = random.choice(possible_targets) |
| 127 | + # Remove target spk: size of possible source spk to anonymize == len(possible_targets) (==247) or you need to add spk target overlap) |
| 128 | + possible_targets.remove(out_spk2target[source_spk]) |
| 129 | + target_spks.append(out_spk2target[source_spk]) |
| 130 | + elif target_selection_algorithm == "random_per_spk": |
| 131 | + for ut in utid: |
| 132 | + source_spk = source_utt2spk[ut] |
| 133 | + if source_spk not in out_spk2target: |
| 134 | + out_spk2target[source_spk] = random.choice(possible_targets) |
| 135 | + target_spks.append(out_spk2target[source_spk]) |
| 136 | + else: |
| 137 | + raise ValueError(f"{target_selection_algorithm} not implemented") |
| 138 | + # Batch conversion |
| 139 | + wav_conv = model.convert(audio, target=target_spks) |
| 140 | + wav_conv = wav_conv.cpu() |
| 141 | + |
| 142 | + def parallel_write(): |
| 143 | + for i in range(wav_conv.shape[0]): |
| 144 | + wav = wav_conv[i] |
| 145 | + if len(wav.shape) == 1: |
| 146 | + wav = wav.unsqueeze(0) # batch == 1 -> len(dst) % batch == 1 |
| 147 | + wav = wav[:, :original_len[i]] |
| 148 | + # write to buffer |
| 149 | + u = utid[i] |
| 150 | + output_file = results_dir / f'{u}.wav' |
| 151 | + torchaudio.save(str(output_file), wav, freq, encoding='PCM_S', bits_per_sample=16) |
| 152 | + p = multiprocessing.Process(target=parallel_write, args=()) |
| 153 | + p.start() |
| 154 | + return p |
| 155 | + |
| 156 | + nj = settings.data_loader_nj |
| 157 | + nj = min(nj, 18) |
| 158 | + p = None |
| 159 | + |
| 160 | + with open(wav_scp_out, 'wt', encoding='utf-8') as writer: |
| 161 | + filtered_wavs = {} |
| 162 | + for u, file in wavscp.items(): |
| 163 | + output_file = results_dir / f'{u}.wav' |
| 164 | + filtered_wavs[u] = file |
| 165 | + |
| 166 | + data_loader = torch.utils.data.DataLoader(Dataset(filtered_wavs, model.get_f0), batch_size=batch_size, num_workers=nj, collate_fn=collate_fn) |
| 167 | + for audio, f0, original_len, utid, freq in data_loader: |
| 168 | + p = process_wav(utid, freq, audio, f0, original_len) |
| 169 | + for u in utid: |
| 170 | + output_file = results_dir / f'{u}.wav' |
| 171 | + writer.writelines(f"{u} {output_file}\n") |
| 172 | + with progress.get_lock(): |
| 173 | + progress.value += batch_size |
| 174 | + if device.startswith("cuda"): |
| 175 | + torch.cuda.empty_cache() |
| 176 | + # wait for last p to write the anonymized audios |
| 177 | + if p: |
| 178 | + p.join() |
0 commit comments