Skip to content

Commit 16e6c6f

Browse files
committed
readme and bin
1 parent 88d7a54 commit 16e6c6f

File tree

4 files changed

+283
-0
lines changed

4 files changed

+283
-0
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ wav_conv = model.convert(torch.rand((1, 77040)), target="1069")
5050
asr_bn = model.get_bn(torch.rand((1, 77040))) # (ASR-BN extraction for disentangled linguistic features (best with hifigan_bn_tdnnf_wav2vec2_vq_48_v1))
5151
```
5252

53+
## Anonymize bin
54+
Once the install.sh script is run, (`INSTALL_KALDI=false` can be set for faster installation), you will
55+
have access to the [`./satools/satools/bin/anonymize`](./satools/satools/bin/anonymize) bin in your path that you can use together
56+
with a config (example: [here](./egs/vc/libritts/configs/anon_any_to_one_for_train)) to anonymize a kaldi like directory.
57+
5358
## Quick JIT anonymization example
5459

5560
This version does not rely on any dependencies using [TorchScript](https://pytorch.org/docs/stable/jit.html).

satools/satools/bin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import pipeline

satools/satools/bin/anonymize

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#!/usr/bin/env python3
2+
3+
description = """
4+
This script anonymize a kaldi/wav.scp formated dataset
5+
It takes a config file and a directory
6+
"""
7+
8+
import os
9+
os.environ["SA_JIT_TWEAK"] = "true"
10+
import sys
11+
import time
12+
from dataclasses import dataclass
13+
import configparser
14+
import argparse
15+
import logging
16+
17+
from multiprocessing import Process, Value
18+
from tqdm import tqdm
19+
20+
import satools.script_utils as script_utils
21+
22+
@dataclass
23+
class Pipeline(script_utils.ConfigParser):
24+
model: str = "large"
25+
f0_modification: str = ""
26+
target_selection_algorithm: str = "?"
27+
target_constant_spkid: str = "?"
28+
results_dir: int = "wav" # output of anonymize wavs ./data/XXXX/wav
29+
batch_size: int = 8
30+
data_loader_nj: int = 5
31+
new_datadir_suffix: str = "_anon"
32+
33+
@dataclass
34+
class Cmd(script_utils.ConfigParser):
35+
device: str = "cuda"
36+
ngpu: script_utils.ngpu = 1
37+
jobs_per_compute_device: int = 1 # number of jobs per gpus/cpus
38+
39+
40+
def update_progress_bar(progress, total):
41+
with tqdm(total=total) as pbar:
42+
while progress.value < total:
43+
pbar.n = progress.value
44+
pbar.refresh()
45+
time.sleep(0.5) # Adjust the sleep time as needed
46+
pbar.n = total
47+
pbar.refresh()
48+
49+
def compute_pipeline(cfg_cmd, cfg_pipeline, directory, wavscp, progress):
50+
import satools.bin.pipeline
51+
satools.bin.pipeline.process_data(directory, cfg_pipeline.target_selection_algorithm, wavscp, cfg_pipeline, progress)
52+
53+
54+
if __name__ == "__main__":
55+
parser = argparse.ArgumentParser(description=description)
56+
parser.add_argument("--config", default="configs/default", required=True)
57+
parser.add_argument("--directory", default="data/default", required=True)
58+
args = parser.parse_args()
59+
60+
logging.info("Reading config")
61+
cfg_parse = configparser.ConfigParser()
62+
cfg_parse.read(args.config)
63+
cfg_parse = script_utils.vartoml(cfg_parse)
64+
65+
cfg_cmd = Cmd().load_from_config(cfg_parse["cmd"])
66+
cfg_pipeline = Pipeline().load_from_config(cfg_parse["pipeline"])
67+
cfg_pipeline.device = cfg_cmd.device
68+
69+
wavscp = script_utils.read_wav_scp(os.path.join(args.directory, "wav.scp"))
70+
71+
wavscp_for_jobs = list(script_utils.split_dict(wavscp, len(cfg_cmd.ngpu) * cfg_cmd.jobs_per_compute_device))
72+
progress = Value('i', 0)
73+
74+
processes = []
75+
index = 0
76+
for gpu_id in cfg_cmd.ngpu:
77+
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
78+
for job_id in range(cfg_cmd.jobs_per_compute_device):
79+
p = Process(target=compute_pipeline, args=(cfg_cmd, cfg_pipeline, args.directory, wavscp_for_jobs[index], progress))
80+
index += 1
81+
processes.append(p)
82+
p.start()
83+
84+
# Start a thread to update the progress bar
85+
progress_thread = Process(target=update_progress_bar, args=(progress, len(wavscp)))
86+
progress_thread.start()
87+
88+
for p in processes:
89+
p.join()
90+
if p.exitcode != 0:
91+
print(f"Process {p.pid} exited with code {p.exitcode}. Terminating.")
92+
for proc in processes:
93+
if proc.is_alive():
94+
proc.terminate()
95+
progress_thread.terminate()
96+
sys.exit(1)
97+
98+
progress_thread.terminate()
99+
logging.info('Done')

satools/satools/bin/pipeline.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
#!/usr/bin/env python3.0
2+
# -*- coding: utf-8 -*-
3+
4+
import os
5+
import shutil
6+
import multiprocessing
7+
from pathlib import Path
8+
import torchaudio
9+
import random
10+
import glob
11+
import logging
12+
13+
import torch
14+
15+
import satools.script_utils as script_utils
16+
from satools.infer_helper import load_model
17+
from satools.utils.kaldi import load_wav_from_scp
18+
19+
def copy_data_dir(dataset_path, output_path):
20+
# Copy utt2spk wav.scp and so on, but not the directories inside (may contains clear or anonymzied *.wav)
21+
os.makedirs(output_path, exist_ok=True)
22+
for p in glob.glob(str(Path(dataset_path) / '*'), recursive=False):
23+
if os.path.isfile(p):
24+
shutil.copy(p, output_path)
25+
26+
class Wav(): # for f0 extraction
27+
def __init__(self, w):
28+
self.wav = w
29+
30+
class Dataset(torch.utils.data.Dataset):
31+
def __init__(self, id_wavs, get_f0_func):
32+
self.all_wavs = list(id_wavs.values())
33+
self.all_keys = list(id_wavs.keys())
34+
self.get_f0_func = get_f0_func
35+
36+
def __len__(self):
37+
return len(self.all_wavs)
38+
39+
def __getitem__(self, index):
40+
audio, freq = load_wav_from_scp(str(self.all_wavs[index]))
41+
f0 = self.get_f0_func(Wav(audio))
42+
return {"utid": self.all_keys[index],
43+
"audio": audio,
44+
"f0": f0,
45+
"freq": freq}
46+
47+
def collate_fn(item_list):
48+
batch_size = len(item_list)
49+
50+
data_list_audio = [i['audio'] for i in item_list]
51+
lengths_tensor_audio = torch.tensor([i.shape[-1] for i in data_list_audio])
52+
max_len_audio = torch.max(lengths_tensor_audio).item()
53+
output_audio = torch.zeros([batch_size, max_len_audio])
54+
for i in range(batch_size):
55+
cur = data_list_audio[i]
56+
cur_len = data_list_audio[i].shape[-1]
57+
output_audio[i, :cur_len] = cur.squeeze()
58+
59+
data_list_f0 = [i['f0'] for i in item_list]
60+
lengths_tensor_f0 = torch.tensor([i.shape[-1] for i in data_list_f0])
61+
max_len_f0 = torch.max(lengths_tensor_f0).item()
62+
output_f0 = torch.zeros([batch_size, max_len_f0])
63+
for i in range(batch_size):
64+
cur = data_list_f0[i]
65+
cur_len = data_list_f0[i].shape[-1]
66+
output_f0[i, :cur_len] = cur.squeeze()
67+
68+
utids = [i['utid'] for i in item_list]
69+
freqs = [i['freq'] for i in item_list]
70+
return output_audio, output_f0, lengths_tensor_audio, utids, freqs
71+
72+
def process_data(dataset_path: str, target_selection_algorithm: str, wavscp: dict, settings: dict, progress):
73+
results_dir = settings.results_dir
74+
dataset_path = Path(str(dataset_path))
75+
output_path = Path(str(dataset_path) + settings.new_datadir_suffix)
76+
device = settings.device
77+
batch_size = settings.batch_size
78+
79+
copy_data_dir(dataset_path, output_path)
80+
results_dir = output_path / results_dir
81+
os.makedirs(results_dir, exist_ok = True)
82+
83+
wav_scp = dataset_path / 'wav.scp'
84+
utt2spk = dataset_path / 'utt2spk'
85+
wav_scp_out = output_path / 'wav.scp'
86+
87+
model = load_model(settings.model)
88+
model.to(device)
89+
model.eval()
90+
possible_targets = model.spk.copy() # For spk and utt target_selection_algorithm random choice
91+
92+
source_utt2spk = script_utils.read_wav_scp(utt2spk)
93+
out_spk2target = {} # For spk target_selection_algorithm
94+
95+
96+
@torch.no_grad()
97+
def process_wav(utid, freq, audio, f0, original_len):
98+
99+
freq = freq[0] # assume all freq = in same batch (and so dataset)
100+
audio = audio.to(device)
101+
102+
# Anonymize function
103+
model.set_f0(f0.to(device)) # CPU extracted by Dataloader (num_workers)
104+
# Batch select target spks from the available model list depending on target_selection_algorithm
105+
target_spks = []
106+
if target_selection_algorithm == "constant": # The best way/most secure to evaluate privacy when applied to all dataset (train included)
107+
target_constant_spkid = settings.target_constant_spkid # For constant target_selection_algorithm
108+
target_spks = [target_constant_spkid]*audio.shape[0]
109+
elif target_selection_algorithm == "bad_for_evaluation":
110+
# This target selection algorithm is bad for evaluation as it does
111+
# not generate suitable training data for the ASV eval training
112+
# procedure. Use it with caution.
113+
for ut in utid:
114+
source_spk = source_utt2spk[ut]
115+
if source_spk not in out_spk2target:
116+
out_spk2target[source_spk] = random.sample(possible_targets, 2)
117+
target_spks.append(random.choice(out_spk2target[source_spk]))
118+
elif target_selection_algorithm == "random_per_utt":
119+
target_spks = []
120+
for ut in utid:
121+
target_spks.append(random.choice(possible_targets))
122+
elif target_selection_algorithm == "random_per_spk_uniq":
123+
for ut in utid:
124+
source_spk = source_utt2spk[ut]
125+
if source_spk not in out_spk2target:
126+
out_spk2target[source_spk] = random.choice(possible_targets)
127+
# Remove target spk: size of possible source spk to anonymize == len(possible_targets) (==247) or you need to add spk target overlap)
128+
possible_targets.remove(out_spk2target[source_spk])
129+
target_spks.append(out_spk2target[source_spk])
130+
elif target_selection_algorithm == "random_per_spk":
131+
for ut in utid:
132+
source_spk = source_utt2spk[ut]
133+
if source_spk not in out_spk2target:
134+
out_spk2target[source_spk] = random.choice(possible_targets)
135+
target_spks.append(out_spk2target[source_spk])
136+
else:
137+
raise ValueError(f"{target_selection_algorithm} not implemented")
138+
# Batch conversion
139+
wav_conv = model.convert(audio, target=target_spks)
140+
wav_conv = wav_conv.cpu()
141+
142+
def parallel_write():
143+
for i in range(wav_conv.shape[0]):
144+
wav = wav_conv[i]
145+
if len(wav.shape) == 1:
146+
wav = wav.unsqueeze(0) # batch == 1 -> len(dst) % batch == 1
147+
wav = wav[:, :original_len[i]]
148+
# write to buffer
149+
u = utid[i]
150+
output_file = results_dir / f'{u}.wav'
151+
torchaudio.save(str(output_file), wav, freq, encoding='PCM_S', bits_per_sample=16)
152+
p = multiprocessing.Process(target=parallel_write, args=())
153+
p.start()
154+
return p
155+
156+
nj = settings.data_loader_nj
157+
nj = min(nj, 18)
158+
p = None
159+
160+
with open(wav_scp_out, 'wt', encoding='utf-8') as writer:
161+
filtered_wavs = {}
162+
for u, file in wavscp.items():
163+
output_file = results_dir / f'{u}.wav'
164+
filtered_wavs[u] = file
165+
166+
data_loader = torch.utils.data.DataLoader(Dataset(filtered_wavs, model.get_f0), batch_size=batch_size, num_workers=nj, collate_fn=collate_fn)
167+
for audio, f0, original_len, utid, freq in data_loader:
168+
p = process_wav(utid, freq, audio, f0, original_len)
169+
for u in utid:
170+
output_file = results_dir / f'{u}.wav'
171+
writer.writelines(f"{u} {output_file}\n")
172+
with progress.get_lock():
173+
progress.value += batch_size
174+
if device.startswith("cuda"):
175+
torch.cuda.empty_cache()
176+
# wait for last p to write the anonymized audios
177+
if p:
178+
p.join()

0 commit comments

Comments
 (0)