Skip to content

Commit

Permalink
Accept torch.Tensor as input
Browse files: browse the repository at this point in the history
  • Loading branch information
dkurt committed May 16, 2024
1 parent f2da2f8 commit 3be4987
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
18 changes: 12 additions & 6 deletions whisperx/asr.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -74,13 +74,16 @@ def decode_batch(tokens: List[List[int]]) -> str:

return text

# NOTE(review): this is a rendered diff hunk with the +/- markers and the
# indentation stripped. The first signature below looks like the pre-commit
# line and the second its replacement (additionally accepting torch.Tensor)
# -- confirm against the commit before treating either as current.
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
def encode(self, features: Union[np.ndarray, torch.Tensor]) -> ctranslate2.StorageView:
# When the model is running on multiple GPUs, the encoder output should be moved
# to the CPU since we don't know which GPU will handle the next job.
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
# unsqueeze if batch size = 1
# A 2-D input is given a leading axis of size 1 before it is converted.
if len(features.shape) == 2:
# NOTE(review): the unconditional np.expand_dims call directly below appears
# to be the removed line; the isinstance branch that follows looks like its
# replacement, using np.expand_dims for ndarrays and .unsqueeze(0) otherwise.
features = np.expand_dims(features, 0)
if isinstance(features, np.ndarray):
features = np.expand_dims(features, 0)
else:
features = features.unsqueeze(0)
# Convert the features with faster_whisper's helper before passing them to
# self.model.encode.
features = faster_whisper.transcribe.get_ctranslate2_storage(features)

return self.model.encode(features, to_cpu=to_cpu)
Expand Down Expand Up @@ -171,19 +174,22 @@ def stack(items):
return final_iterator

def transcribe(
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
self, audio: Union[str, np.ndarray, torch.Tensor], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
) -> TranscriptionResult:
if isinstance(audio, str):
audio = load_audio(audio)

if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio)

def data(audio, segments):
for seg in segments:
f1 = int(seg['start'] * SAMPLE_RATE)
f2 = int(seg['end'] * SAMPLE_RATE)
# print(f2-f1)
yield {'inputs': audio[f1:f2]}

vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
vad_segments = self.vad_model({"waveform": audio.unsqueeze(0), "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(
vad_segments,
chunk_size,
Expand All @@ -203,7 +209,7 @@ def data(audio, segments):
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
self.model.model.is_multilingual, task=task,
language=language)

if self.suppress_numerals:
previous_suppress_tokens = self.options.suppress_tokens
numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
Expand Down Expand Up @@ -242,7 +248,7 @@ def data(audio, segments):
return {"segments": segments, "language": language}


def detect_language(self, audio: np.ndarray):
def detect_language(self, audio: Union[np.ndarray, torch.Tensor]):
if audio.shape[0] < N_SAMPLES:
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
model_n_mels = self.model.feat_kwargs.get("feature_size")
Expand Down
16 changes: 11 additions & 5 deletions whisperx/diarize.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,11 +18,17 @@ def __init__(
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)

def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
def __call__(self, audio: Union[str, np.ndarray, torch.Tensor], num_speakers=None, min_speakers=None, max_speakers=None):
if isinstance(audio, str):
audio = load_audio(audio)

audio = audio[None, :]

if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio)

audio_data = {
'waveform': torch.from_numpy(audio[None, :]),
'waveform': audio,
'sample_rate': SAMPLE_RATE
}
segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
Expand All @@ -47,7 +53,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
# sum over speakers
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
seg["speaker"] = speaker

# assign speaker to words
if 'words' in seg:
for word in seg['words']:
Expand All @@ -63,8 +69,8 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
# sum over speakers
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
word["speaker"] = speaker
return transcript_result

return transcript_result


class Segment:
Expand Down

0 comments on commit 3be4987

Please sign in to comment.