Merge branch 'add_embeddings'
cnesp committed Jan 14, 2025
2 parents 1027367 + 6823d96 commit 743c4a5
Showing 2 changed files with 138 additions and 56 deletions.
98 changes: 64 additions & 34 deletions whisperx/asr.py
@@ -24,11 +24,12 @@ def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens.append(i)
return numeral_symbol_tokens


class WhisperModel(faster_whisper.WhisperModel):
'''
"""
FasterWhisperModel provides batched inference for faster-whisper.
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
'''
"""

def generate_segment_batched(
self,
@@ -59,15 +60,15 @@ def generate_segment_batched(
)

result = self.model.generate(
encoder_output,
[prompt] * batch_size,
beam_size=options.beam_size,
patience=options.patience,
length_penalty=options.length_penalty,
max_length=self.max_length,
suppress_blank=options.suppress_blank,
suppress_tokens=options.suppress_tokens,
)
encoder_output,
[prompt] * batch_size,
beam_size=options.beam_size,
patience=options.patience,
length_penalty=options.length_penalty,
max_length=self.max_length,
suppress_blank=options.suppress_blank,
suppress_tokens=options.suppress_tokens,
)

tokens_batch = [x.sequences_ids[0] for x in result]

@@ -93,10 +94,12 @@ def encode(self, features: np.ndarray) -> ctranslate2.StorageView:

return self.model.encode(features, to_cpu=to_cpu)


class FasterWhisperPipeline(Pipeline):
"""
Huggingface Pipeline wrapper for FasterWhisperModel.
"""

# TODO:
# - add support for timestamp mode
# - add support for custom inference kwargs
@@ -121,7 +124,9 @@ def __init__(
self.suppress_numerals = suppress_numerals
self._batch_size = kwargs.pop("batch_size", None)
self._num_workers = 1
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
self._preprocess_params, self._forward_params, self._postprocess_params = (
self._sanitize_parameters(**kwargs)
)
self.call_count = 0
self.framework = framework
if self.framework == "pt":
@@ -147,18 +152,20 @@ def _sanitize_parameters(self, **kwargs):
return preprocess_kwargs, {}, {}

def preprocess(self, audio):
audio = audio['inputs']
audio = audio["inputs"]
model_n_mels = self.model.feat_kwargs.get("feature_size")
features = log_mel_spectrogram(
audio,
n_mels=model_n_mels if model_n_mels is not None else 80,
padding=N_SAMPLES - audio.shape[0],
)
return {'inputs': features}
return {"inputs": features}

def _forward(self, model_inputs):
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
return {'text': outputs}
outputs = self.model.generate_segment_batched(
model_inputs["inputs"], self.tokenizer, self.options
)
return {"text": outputs}

def postprocess(self, model_outputs):
return model_outputs
@@ -178,10 +185,17 @@ def get_iterator(
# TODO hack by collating feature_extractor and image_processor

def stack(items):
return {'inputs': torch.stack([x['inputs'] for x in items])}
dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack)
model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
return {"inputs": torch.stack([x["inputs"] for x in items])}

dataloader = torch.utils.data.DataLoader(
dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack
)
model_iterator = PipelineIterator(
dataloader, self.forward, forward_params, loader_batch_size=batch_size
)
final_iterator = PipelineIterator(
model_iterator, self.postprocess, postprocess_params
)
return final_iterator

def transcribe(
@@ -201,8 +215,8 @@ def transcribe(

def data(audio, segments):
for seg in segments:
f1 = int(seg['start'] * SAMPLE_RATE)
f2 = int(seg['end'] * SAMPLE_RATE)
f1 = int(seg["start"] * SAMPLE_RATE)
f2 = int(seg["end"] * SAMPLE_RATE)
# print(f2-f1)
yield {'inputs': audio[f1:f2]}

@@ -253,21 +267,29 @@ def data(audio, segments):
segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
total_segments = len(vad_segments)
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
for idx, out in enumerate(
self.__call__(
data(audio, vad_segments),
batch_size=batch_size,
num_workers=num_workers,
)
):
if print_progress:
base_progress = ((idx + 1) / total_segments) * 100
percent_complete = base_progress / 2 if combined_progress else base_progress
percent_complete = (
base_progress / 2 if combined_progress else base_progress
)
print(f"Progress: {percent_complete:.2f}%...")
text = out['text']
text = out["text"]
if batch_size in [0, 1, None]:
text = text[0]
if verbose:
print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
segments.append(
{
"text": text,
"start": round(vad_segments[idx]['start'], 3),
"end": round(vad_segments[idx]['end'], 3)
"start": round(vad_segments[idx]["start"], 3),
"end": round(vad_segments[idx]["end"], 3),
}
)

@@ -283,16 +305,22 @@ def data(audio, segments):

def detect_language(self, audio: np.ndarray) -> str:
if audio.shape[0] < N_SAMPLES:
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
print(
"Warning: audio is shorter than 30s, language detection may be inaccurate."
)
model_n_mels = self.model.feat_kwargs.get("feature_size")
segment = log_mel_spectrogram(audio[: N_SAMPLES],
n_mels=model_n_mels if model_n_mels is not None else 80,
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
segment = log_mel_spectrogram(
audio[:N_SAMPLES],
n_mels=model_n_mels if model_n_mels is not None else 80,
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0],
)
encoder_output = self.model.encode(segment)
results = self.model.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
print(
f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio..."
)
return language


Expand Down Expand Up @@ -341,10 +369,12 @@ def load_model(
if language is not None:
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
else:
print("No language specified, language will be first be detected for each audio file (increases inference time).")
print(
"No language specified, language will be first be detected for each audio file (increases inference time)."
)
tokenizer = None

default_asr_options = {
default_asr_options = {
"beam_size": 5,
"best_of": 5,
"patience": 1,
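
The asr.py hunks above are formatting-only (double-quoted strings, wrapped call sites, and blank lines between top-level definitions), so the pipeline's public interface is unchanged. For context, a minimal transcription sketch against the reformatted pipeline; the model size, compute type, and audio path are placeholder assumptions, while load_model and load_audio are the helpers this repository already exports.

import whisperx

device = "cuda"                                 # assumption: CUDA-capable machine
audio = whisperx.load_audio("example.wav")      # placeholder path

# load_model constructs the FasterWhisperPipeline reformatted above
model = whisperx.load_model("large-v2", device, compute_type="float16")

# transcribe() feeds VAD segments through generate_segment_batched in batches
result = model.transcribe(audio, batch_size=16, print_progress=True)
for seg in result["segments"]:
    print(f"[{seg['start']:.2f} --> {seg['end']:.2f}] {seg['text']}")
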
96 changes: 74 additions & 22 deletions whisperx/diarize.py
@@ -17,26 +17,60 @@ def __init__(
):
if isinstance(device, str):
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
self.model = Pipeline.from_pretrained(
model_name, use_auth_token=use_auth_token
).to(device)

def __call__(
self,
audio: Union[str, np.ndarray],
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
return_embeddings: Optional[bool] = False,
):
if isinstance(audio, str):
audio = load_audio(audio)
audio_data = {
'waveform': torch.from_numpy(audio[None, :]),
'sample_rate': SAMPLE_RATE
"waveform": torch.from_numpy(audio[None, :]),
"sample_rate": SAMPLE_RATE,
}
segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
return diarize_df
if return_embeddings:
segments, embeddings = self.model(
audio_data,
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
return_embeddings=return_embeddings,
)
else:
segments = self.model(
audio_data,
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)
diarize_df = pd.DataFrame(
segments.itertracks(yield_label=True),
columns=["segment", "label", "speaker"],
)

if return_embeddings:
embeddings_list = []
speaker_list = []
for s, speaker in enumerate(segments.labels()):
embeddings_list.append(embeddings[s])
speaker_list.append(speaker)
embeddings_df = pd.DataFrame(
data={"speaker": speaker_list, "embeddings": embeddings_list}
)
diarize_df["start"] = diarize_df["segment"].apply(lambda x: x.start)
diarize_df["end"] = diarize_df["segment"].apply(lambda x: x.end)

if return_embeddings:
return diarize_df, embeddings_df
else: # return diarize_df only
return diarize_df


def assign_word_speakers(
Expand All @@ -47,35 +81,53 @@ def assign_word_speakers(
transcript_segments = transcript_result["segments"]
for seg in transcript_segments:
# assign speaker to segment (if any)
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
diarize_df["intersection"] = np.minimum(
diarize_df["end"], seg["end"]
) - np.maximum(diarize_df["start"], seg["start"])
diarize_df["union"] = np.maximum(diarize_df["end"], seg["end"]) - np.minimum(
diarize_df["start"], seg["start"]
)
# remove no hit, otherwise we look for closest (even negative intersection...)
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
dia_tmp = diarize_df[diarize_df["intersection"] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) > 0:
# sum over speakers
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
speaker = (
dia_tmp.groupby("speaker")["intersection"]
.sum()
.sort_values(ascending=False)
.index[0]
)
seg["speaker"] = speaker

# assign speaker to words
if 'words' in seg:
for word in seg['words']:
if 'start' in word:
diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start'])
if "words" in seg:
for word in seg["words"]:
if "start" in word:
diarize_df["intersection"] = np.minimum(
diarize_df["end"], word["end"]
) - np.maximum(diarize_df["start"], word["start"])
diarize_df["union"] = np.maximum(
diarize_df["end"], word["end"]
) - np.minimum(diarize_df["start"], word["start"])
# remove no hit
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
dia_tmp = diarize_df[diarize_df["intersection"] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) > 0:
# sum over speakers
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
speaker = (
dia_tmp.groupby("speaker")["intersection"]
.sum()
.sort_values(ascending=False)
.index[0]
)
word["speaker"] = speaker
return transcript_result

return transcript_result


class Segment:
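
The diarize.py changes carry the functional payload of this merge: DiarizationPipeline.__call__ gains a return_embeddings flag and, when it is set, also returns a DataFrame pairing each speaker label with its embedding vector; assign_word_speakers is only reformatted. A minimal end-to-end sketch under stated assumptions: the Hugging Face token, model size, and audio path are placeholders, and the package-level DiarizationPipeline / assign_word_speakers usage follows the repository's README.

import whisperx

device = "cuda"                                 # assumption: CUDA-capable machine
hf_token = "hf_..."                             # placeholder Hugging Face token
audio = whisperx.load_audio("example.wav")      # placeholder path

# transcript from the ASR pipeline (see the asr.py sketch above)
asr_model = whisperx.load_model("large-v2", device, compute_type="float16")
result = asr_model.transcribe(audio, batch_size=16)

diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)

# new in this merge: request per-speaker embeddings alongside the segment table
diarize_df, embeddings_df = diarize_model(
    audio,
    min_speakers=1,
    max_speakers=4,
    return_embeddings=True,
)
print(embeddings_df.head())                     # one row per speaker: label + embedding

# speaker assignment is unchanged; for word-level speakers, run whisperx.align first
result = whisperx.assign_word_speakers(diarize_df, result)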
