diff --git a/worker/generate_speaker_clusters.py b/worker/generate_speaker_clusters.py
new file mode 100644
index 00000000..c2d0630c
--- /dev/null
+++ b/worker/generate_speaker_clusters.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python3
+import argparse
+import sys
+import tempfile
+from asyncio import run
+from zipfile import ZipFile
+
+import automerge
+import numpy as np
+from transcribee_worker.identify_speakers import identify_speakers
+from transcribee_worker.util import load_audio
+
+
+async def main(args):
+    z = ZipFile(args.infile, mode="r", allowZip64=True)
+    # Locate the automerge document and the media file in the export,
+    # skipping macOS resource-fork entries.
+    for info in z.filelist:
+        if "__MACOSX" not in info.filename:
+            if info.filename.endswith(".automerge"):
+                automerge_doc = info.filename
+            elif info.filename.endswith(".mp3"):
+                media_file = info.filename
+
+    automerge_doc = automerge.load(z.read(automerge_doc))
+    media_file = z.read(media_file)
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        tmpfile.write(media_file)
+        tmpfile.flush()  # ensure the audio is on disk before it is read back by name
+        audio = load_audio(tmpfile.name)[0]
+
+    with automerge.transaction(automerge_doc, "dummy") as doc:
+        embeddings = await identify_speakers(
+            args.number_of_speakers, audio, doc, lambda *args, **kwargs: ...
+        )
+        np.savez(args.outfile, np.stack(embeddings))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="run speaker identification for each paragraph in a document"
+    )
+    parser.add_argument(
+        "--number_of_speakers",
+        default=None,
+        metavar="int",
+        type=int,
+        help="number of speakers",
+    )
+    parser.add_argument(
+        "infile",
+        default="export.zip",
+        type=argparse.FileType("rb"),
+        help="the transcribee export that should be considered",
+    )
+    parser.add_argument(
+        "outfile",
+        default=sys.stdout,
+        type=argparse.FileType("wb"),
+        help="file to write the output to",
+    )
+
+    args = parser.parse_args()
+
+    run(main(args))
+
+    args.outfile.close()
diff --git a/worker/transcribee_worker/identify_speakers.py b/worker/transcribee_worker/identify_speakers.py
index b2bc2e4e..7f855d8f 100644
--- a/worker/transcribee_worker/identify_speakers.py
+++ b/worker/transcribee_worker/identify_speakers.py
@@ -41,7 +41,7 @@ async def identify_speakers(
     doc: Document,
     progress_callback: ProgressCallbackType,
 ):
-    def work(_queue):
+    def work(queue):
         logging.info("Running Speaker Identification")
 
         if len(doc.children) == 0:
@@ -91,6 +91,7 @@ def time_to_sample(time: float | None):
             wav_tensor = torch.tensor(wav[np.newaxis, :])
             embedding = classifier.encode_batch(wav_tensor)
             embeddings.append(embedding[0, 0].detach().numpy())
+            queue.submit(embedding[0, 0].detach().numpy())
 
             progress_callback(
                 step="clustering speaker embeddings",
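
For context, a minimal sketch of how the exported embeddings could be consumed downstream. The output filename `embeddings.npz`, the speaker count of 3, and the use of scikit-learn's `AgglomerativeClustering` are illustrative assumptions, not part of this change; the `arr_0` key is numpy's documented default for a positional array passed to `np.savez`.

```python
# Hypothetical consumer of the file written by generate_speaker_clusters.py,
# e.g. after running:
#   python generate_speaker_clusters.py --number_of_speakers 3 export.zip embeddings.npz
import numpy as np
from sklearn.cluster import AgglomerativeClustering  # assumed to be available

# np.savez was called with an unnamed array, so numpy stores it under "arr_0".
embeddings = np.load("embeddings.npz")["arr_0"]  # shape: (n_paragraphs, embedding_dim)

# Group the per-paragraph embeddings into speaker clusters (illustrative choice
# of algorithm; the worker applies its own clustering step internally).
labels = AgglomerativeClustering(n_clusters=3).fit_predict(embeddings)
print(labels)  # one speaker label per paragraph
```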