-
Notifications
You must be signed in to change notification settings - Fork 0
/
diarize.py
82 lines (64 loc) · 2.46 KB
/
diarize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# For Mac M processors
# pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
# https://developer.apple.com/metal/pytorch/
# https://developer.apple.com/metal/mps/
# https://huggingface.co/pyannote/speaker-diarization-3.1
import yaml
import torch
import sys
import os
from pyannote.audio import Pipeline
def test_torch():
    """Print a small tensor placed on the MPS device, or a notice if MPS is absent.

    Checks ``torch.backends.mps.is_available()``. When it is true, a
    one-element tensor of ones is created on the ``"mps"`` device and printed;
    otherwise the message ``"MPS device not found."`` is printed.
    """
    # Guard clause: without MPS there is nothing to allocate.
    if not torch.backends.mps.is_available():
        print("MPS device not found.")
        return

    probe = torch.ones(1, device=torch.device("mps"))
    print(probe)
def get_huggingface_token(yaml_file_path="../apis.yaml"):
    """Read the Hugging Face access token from a YAML credentials file.

    Expected file layout (for example a sample ``apis.yaml``)::

        api:
          huggingface_token: "your_huggingface_token_here"

    Args:
        yaml_file_path: Path of the YAML file to read. The default is a
            relative path, so it is resolved against the current working
            directory of the process.

    Returns:
        The value stored under ``api`` -> ``huggingface_token``, or ``None``
        when the file is empty, is not a mapping at the top level, or does
        not contain that nested key.

    Raises:
        OSError: If the file cannot be opened (``FileNotFoundError`` when it
            does not exist).
    """
    with open(yaml_file_path, "r", encoding="utf-8") as file:
        data = yaml.safe_load(file)

    # safe_load yields None for an empty document and may yield a list or a
    # scalar; only a mapping can hold the nested 'api' section.
    if not isinstance(data, dict):
        return None

    api_section = data.get("api")
    # Same reasoning one level down: 'api: some-string' must not be indexed.
    if not isinstance(api_section, dict):
        return None

    return api_section.get("huggingface_token")
def _pick_device():
    """Return the torch device to run on: MPS first, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def diarize(inputfile, outputfile):
    """Run speaker diarization on an audio file and write the speaker turns.

    Loads the ``pyannote/speaker-diarization-3.1`` pipeline using the token
    returned by ``get_huggingface_token()``, moves it to the best available
    device, applies it to ``inputfile`` and writes one line per speaker turn
    to ``outputfile``. Every line is also echoed to standard output.

    Args:
        inputfile: Path of the audio file handed to the pipeline.
        outputfile: Path of the text file to (over)write with the result.

    Raises:
        RuntimeError: If the pretrained pipeline could not be loaded.
    """
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=get_huggingface_token(),
    )
    # NOTE(review): from_pretrained appears to return None (rather than raise)
    # when the gated model cannot be downloaded, e.g. with a missing or invalid
    # token - confirm against the installed pyannote.audio version.
    if pipeline is None:
        raise RuntimeError(
            "Could not load 'pyannote/speaker-diarization-3.1'; "
            "check the Hugging Face token and the model access conditions."
        )

    # Use an accelerator only when it actually exists, so the script also runs
    # on machines without an Apple M-series GPU.
    device = _pick_device()
    pipeline.to(device)
    device_messages = {
        "mps": "Using MPS GPU (Apple M processors)",
        "cuda": "Using CUDA GPU",
    }
    print(device_messages.get(device.type, "Using CPU"))

    # apply pretrained pipeline
    diarization = pipeline(inputfile)

    # write the result to the file and echo it to the console
    with open(outputfile, "w") as f:
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            line = f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}"
            print(line, file=f)
            print(line)
def main():
    """Command-line entry point: validate the arguments, then diarize.

    Expects ``sys.argv`` to hold the input audio path followed by the output
    text path. Prints a message and exits with status 1 when fewer than two
    arguments are given, when the input file does not exist, or when its name
    does not end in ``.wav`` (case-insensitive). Otherwise calls ``diarize``.
    """
    if len(sys.argv) < 3:
        print("Usage: python3 diarize.py inputfile outputfile")
        sys.exit(1)

    input_audio = sys.argv[1]
    outputfile = sys.argv[2]

    # Check if the input file exists and is a .wav file
    if not os.path.exists(input_audio):
        print(f"Error: The file {input_audio} does not exist.")
        sys.exit(1)
    if not input_audio.lower().endswith('.wav'):
        print("Error: The input file is not a .wav file.")
        sys.exit(1)

    # f-string so the actual paths are shown instead of the literal
    # "{input_audio}" / "{outputfile}" placeholders.
    print(f"Diarizing file {input_audio} to {outputfile}")
    diarize(input_audio, outputfile)
    print("*************** Done")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()