diff --git a/whisperx/asr.py b/whisperx/asr.py index 0ccaf92ba..0181fc0bb 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -74,11 +74,13 @@ def decode_batch(tokens: List[List[int]]) -> str: return text - def encode(self, features: np.ndarray) -> ctranslate2.StorageView: + def encode(self, features: Union[np.ndarray, torch.Tensor]) -> ctranslate2.StorageView: # When the model is running on multiple GPUs, the encoder output should be moved # to the CPU since we don't know which GPU will handle the next job. to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 # unsqueeze if batch size = 1 + if isinstance(features, torch.Tensor): + features = features.cpu().numpy() if len(features.shape) == 2: features = np.expand_dims(features, 0) features = faster_whisper.transcribe.get_ctranslate2_storage(features) @@ -171,11 +173,14 @@ def stack(items): return final_iterator def transcribe( - self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False + self, audio: Union[str, np.ndarray, torch.Tensor], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False ) -> TranscriptionResult: if isinstance(audio, str): audio = load_audio(audio) + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + def data(audio, segments): for seg in segments: f1 = int(seg['start'] * SAMPLE_RATE) @@ -183,7 +188,7 @@ def data(audio, segments): # print(f2-f1) yield {'inputs': audio[f1:f2]} - vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE}) + vad_segments = self.vad_model({"waveform": audio.unsqueeze(0), "sample_rate": SAMPLE_RATE}) vad_segments = merge_chunks( vad_segments, chunk_size, @@ -203,7 +208,7 @@ def data(audio, segments): self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task=task, language=language) - + if self.suppress_numerals: previous_suppress_tokens = self.options.suppress_tokens numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer) @@ -242,7 +247,7 @@ def data(audio, segments): return {"segments": segments, "language": language} - def detect_language(self, audio: np.ndarray): + def detect_language(self, audio: Union[np.ndarray, torch.Tensor]): if audio.shape[0] < N_SAMPLES: print("Warning: audio is shorter than 30s, language detection may be inaccurate.") model_n_mels = self.model.feat_kwargs.get("feature_size") diff --git a/whisperx/diarize.py b/whisperx/diarize.py index c327c9320..ba76ce708 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -18,11 +18,17 @@ def __init__( device = torch.device(device) self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device) - def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None): + def __call__(self, audio: Union[str, np.ndarray, torch.Tensor], num_speakers=None, min_speakers=None, max_speakers=None): if isinstance(audio, str): audio = load_audio(audio) + + audio = audio[None, :] + + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + audio_data = { - 'waveform': torch.from_numpy(audio[None, :]), + 'waveform': audio, 'sample_rate': SAMPLE_RATE } segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers) @@ -47,7 +53,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False): # sum over speakers speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] seg["speaker"] = speaker - + # assign speaker to words if 'words' in seg: for word in seg['words']: @@ -63,8 +69,8 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False): # sum over speakers speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] word["speaker"] = speaker - - return transcript_result + + return transcript_result class Segment: