Skip to content

Commit

Permalink
worker: add generate_speaker_clusters.py
Browse files Browse the repository at this point in the history
This manually performs the speaker identification step and writes the
speaker embeddings to a file. Can be used to debug speaker clustering / identification
  • Loading branch information
rroohhh committed Feb 26, 2024
1 parent cc8fecc commit c4a031d
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
64 changes: 64 additions & 0 deletions worker/generate_speaker_clusters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#!/usr/bin/env python3
import argparse
import sys
import tempfile
from asyncio import run
from zipfile import ZipFile

import automerge
import numpy as np
from transcribee_worker.identify_speakers import identify_speakers
from transcribee_worker.util import load_audio


async def main(args):
z = ZipFile(args.infile, mode="r", allowZip64=True)
for info in z.filelist:
if "__MACOSX" not in info.filename:
if info.filename.endswith(".automerge"):
automerge_doc = info.filename
elif info.filename.endswith(".mp3"):
media_file = info.filename

automerge_doc = automerge.load(z.read(automerge_doc))
media_file = z.read(media_file)
with tempfile.NamedTemporaryFile() as tmpfile:
tmpfile.write(media_file)
audio = load_audio(tmpfile.name)[0]

with automerge.transaction(automerge_doc, "dummy") as doc:
embeddings = await identify_speakers(
args.number_of_speakers, audio, doc, lambda *args, **kwargs: ...
)
np.savez(args.outfile, np.stack(embeddings))


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="run speaker identification for each paragraph in a document"
)
parser.add_argument(
"--number_of_speakers",
default=None,
metavar="int",
type=int,
help="number of speakers",
)
parser.add_argument(
"infile",
default="export.zip",
type=argparse.FileType("rb"),
help="the transcribee export that should be considered",
)
parser.add_argument(
"outfile",
default=sys.stdout,
type=argparse.FileType("wb"),
help="file to write the output to",
)

args = parser.parse_args()

run(main(args))

args.outfile.close()
3 changes: 2 additions & 1 deletion worker/transcribee_worker/identify_speakers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def identify_speakers(
doc: Document,
progress_callback: ProgressCallbackType,
):
def work(_queue):
def work(queue):
logging.info("Running Speaker Identification")

if len(doc.children) == 0:
Expand Down Expand Up @@ -91,6 +91,7 @@ def time_to_sample(time: float | None):
wav_tensor = torch.tensor(wav[np.newaxis, :])
embedding = classifier.encode_batch(wav_tensor)
embeddings.append(embedding[0, 0].detach().numpy())
queue.submit(embedding[0, 0].detach().numpy())

progress_callback(
step="clustering speaker embeddings",
Expand Down

0 comments on commit c4a031d

Please sign in to comment.