forked from victor-upmeet/whisperx-replicate
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
304 lines (225 loc) · 11.5 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
from cog import BasePredictor, Input, Path, BaseModel
from pydub import AudioSegment
from typing import Any
from whisperx.audio import N_SAMPLES, log_mel_spectrogram
import gc
import math
import os
import shutil
import whisperx
import tempfile
import time
import torch
from dotenv import load_dotenv
import os
import requests
# Inference settings shared by every model load in this module.
compute_type = "float16"  # change to "int8" if low on GPU mem (may reduce accuracy)
device = "cuda"
whisper_arch = "./models/faster-whisper-large-v3"
chunk_size = 15

# Pull variables from a local .env file into the process environment first,
# then read the Hugging Face token that the diarization pipeline is given.
load_dotenv()
huggingface_access_token = os.environ.get('hf_token')
class Output(BaseModel):
    """Payload returned by ``Predictor.predict``."""

    # Transcription segments from WhisperX, possibly aligned and/or diarized.
    segments: Any
    # Language code reported by the transcription step.
    detected_language: str
class Predictor(BasePredictor):
    """Cog predictor: WhisperX transcription with optional alignment and diarization."""

    def setup(self):
        """Load the diarization pipeline and stage the bundled VAD weights.

        The pyannote diarization pipeline is created once and reused by every
        prediction. If a VAD segmentation checkpoint is bundled under
        ``./models/vad``, it is copied into the torch cache directory unless a
        copy already exists there.
        """
        self.diarize_model = whisperx.DiarizationPipeline(
            model_name='pyannote/speaker-diarization-3.1',
            use_auth_token=huggingface_access_token, device=device)

        source_folder = './models/vad'
        # NOTE(review): relative to the working directory; presumably meant to
        # resolve to /root/.cache/torch inside the container — confirm.
        destination_folder = '../root/.cache/torch'
        file_name = 'whisperx-vad-segmentation.bin'

        os.makedirs(destination_folder, exist_ok=True)
        source_file_path = os.path.join(source_folder, file_name)
        if os.path.exists(source_file_path):
            destination_file_path = os.path.join(destination_folder, file_name)
            if not os.path.exists(destination_file_path):
                shutil.copy(source_file_path, destination_folder)

    def predict(
            self,
            audio_file: Path = Input(description="Audio file"),
            language: str = Input(
                description="ISO code of the language spoken in the audio, specify None to perform language detection",
                default=None),
            language_detection_min_prob: float = Input(
                description="If language is not specified, then the language will be detected recursively on different "
                            "parts of the file until it reaches the given probability",
                default=0
            ),
            language_detection_max_tries: int = Input(
                description="If language is not specified, then the language will be detected following the logic of "
                            "language_detection_min_prob parameter, but will stop after the given max retries. If max "
                            "retries is reached, the most probable language is kept.",
                default=5
            ),
            initial_prompt: str = Input(
                description="Optional text to provide as a prompt for the first window",
                default=None),
            batch_size: int = Input(
                description="Parallelization of input audio transcription",
                default=64),
            temperature: float = Input(
                description="Temperature to use for sampling",
                default=0),
            vad_onset: float = Input(
                description="VAD onset",
                default=0.500),
            vad_offset: float = Input(
                description="VAD offset",
                default=0.363),
            align_output: bool = Input(
                description="Aligns whisper output to get accurate word-level timestamps",
                default=False),
            diarization: bool = Input(
                description="Assign speaker ID labels",
                default=False),
            min_speakers: int = Input(
                description="Minimum number of speakers if diarization is activated (leave blank if unknown)",
                default=None),
            max_speakers: int = Input(
                description="Maximum number of speakers if diarization is activated (leave blank if unknown)",
                default=None),
            debug: bool = Input(
                description="Print out compute/inference times and memory usage information",
                default=False)
    ) -> Output:
        """Transcribe ``audio_file`` and optionally align and diarize the result.

        When ``language`` is None, ``language_detection_min_prob`` > 0 and the
        audio is longer than 30 000 ms, the language is detected up front on up
        to ``language_detection_max_tries`` equally spaced 30 s windows, and
        the result is passed to ``whisperx.load_model``. Otherwise ``language``
        is passed through as given.

        Returns:
            Output with the transcription segments and the language reported
            by ``model.transcribe``.
        """
        with torch.inference_mode():
            asr_options = {
                "temperatures": [temperature],
                "initial_prompt": initial_prompt
            }

            vad_options = {
                "vad_onset": vad_onset,
                "vad_offset": vad_offset
            }

            audio_duration = get_audio_duration(audio_file)

            if language is None and language_detection_min_prob > 0 and audio_duration > 30000:
                segments_duration_ms = 30000

                # Never ask for more 30 s windows than the audio holds, and always
                # try at least once: with 0 (or negative) tries the start list
                # would be empty and detect_language would index into it.
                language_detection_max_tries = max(1, min(
                    language_detection_max_tries,
                    math.floor(audio_duration / segments_duration_ms)
                ))

                segments_starts = distribute_segments_equally(audio_duration, segments_duration_ms,
                                                              language_detection_max_tries)

                print("Detecting languages on segments starting at " + ', '.join(map(str, segments_starts)))

                detected_language_details = detect_language(audio_file, segments_starts, language_detection_min_prob,
                                                            language_detection_max_tries, asr_options, vad_options)

                detected_language_code = detected_language_details["language"]
                detected_language_prob = detected_language_details["probability"]
                detected_language_iterations = detected_language_details["iterations"]

                print(f"Detected language {detected_language_code} ({detected_language_prob:.2f}) after "
                      f"{detected_language_iterations} iterations.")

                language = detected_language_details["language"]

            start_time = time.time_ns() / 1e6

            model = whisperx.load_model(whisper_arch, device, compute_type=compute_type, language=language,
                                        asr_options=asr_options, vad_options=vad_options)

            if debug:
                elapsed_time = time.time_ns() / 1e6 - start_time
                print(f"Duration to load model: {elapsed_time:.2f} ms")

            start_time = time.time_ns() / 1e6

            audio = whisperx.load_audio(audio_file)

            if debug:
                elapsed_time = time.time_ns() / 1e6 - start_time
                print(f"Duration to load audio: {elapsed_time:.2f} ms")

            start_time = time.time_ns() / 1e6

            result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size)
            detected_language = result["language"]

            if debug:
                elapsed_time = time.time_ns() / 1e6 - start_time
                print(f"Duration to transcribe: {elapsed_time:.2f} ms")

            # Drop the last reference to the ASR model *before* collecting;
            # otherwise gc / empty_cache run while the model is still alive and
            # its memory cannot be released.
            del model
            gc.collect()
            torch.cuda.empty_cache()

            if align_output:
                if detected_language in whisperx.alignment.DEFAULT_ALIGN_MODELS_TORCH or detected_language in whisperx.alignment.DEFAULT_ALIGN_MODELS_HF:
                    result = align(audio, result, debug)
                else:
                    print(f"Cannot align output as language {detected_language} is not supported for alignment")

            if diarization:
                result = self.diarize(audio, result, debug, min_speakers, max_speakers)

            if debug:
                print(f"max gpu memory allocated over runtime: {torch.cuda.max_memory_reserved() / (1024 ** 3):.2f} GB")

        return Output(
            segments=result["segments"],
            detected_language=detected_language
        )

    def diarize(self, audio, result, debug, min_speakers, max_speakers):
        """Attach speaker labels to ``result`` using the pipeline loaded in ``setup``.

        Runs the diarization pipeline on ``audio`` (optionally bounded by
        ``min_speakers`` / ``max_speakers``) and merges its output into
        ``result`` via ``whisperx.assign_word_speakers``.
        """
        start_time = time.time_ns() / 1e6

        diarize_segments = self.diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)

        result = whisperx.assign_word_speakers(diarize_segments, result)

        if debug:
            elapsed_time = time.time_ns() / 1e6 - start_time
            print(f"Duration to diarize segments: {elapsed_time:.2f} ms")

        gc.collect()
        torch.cuda.empty_cache()

        return result
def get_audio_duration(file_path):
    """Return the duration of the audio file at ``file_path`` in milliseconds.

    The whole file is decoded with pydub; the length of an ``AudioSegment`` is
    its duration in milliseconds.
    """
    decoded_audio = AudioSegment.from_file(file_path)
    return len(decoded_audio)
def detect_language(full_audio_file_path, segments_starts, language_detection_min_prob,
                    language_detection_max_tries, asr_options, vad_options, iteration=1):
    """Detect the spoken language by probing successive 30 s windows of a file.

    Starting at window number ``iteration`` (a 1-based index into
    ``segments_starts``), each window is cut from ``full_audio_file_path`` and
    run through the Whisper encoder and language-detection head. Probing stops
    as soon as a window reaches ``language_detection_min_prob``, or when the
    iteration number reaches ``language_detection_max_tries``, or when
    ``segments_starts`` is exhausted.

    Args:
        full_audio_file_path: path of the audio file to probe.
        segments_starts: window start offsets in milliseconds.
        language_detection_min_prob: probability at which probing stops early.
        language_detection_max_tries: highest iteration number to probe.
        asr_options: forwarded to ``whisperx.load_model``.
        vad_options: forwarded to ``whisperx.load_model``.
        iteration: 1-based index of the first window to probe.

    Returns:
        dict with ``language`` (code), ``probability`` and ``iterations`` (the
        1-based iteration that produced it). This is the most probable of the
        probed windows; ties go to the earliest window.
    """
    # Load the model once and reuse it for every window instead of reloading
    # the full Whisper model per window.
    model = whisperx.load_model(whisper_arch, device, compute_type=compute_type, asr_options=asr_options,
                                vad_options=vad_options)
    try:
        model_n_mels = model.model.feat_kwargs.get("feature_size")
        n_mels = model_n_mels if model_n_mels is not None else 80
        # Never index past the supplied start offsets.
        last_iteration = min(language_detection_max_tries, len(segments_starts))

        best = None
        while True:
            start_ms = segments_starts[iteration - 1]

            audio_segment_file_path = extract_audio_segment(full_audio_file_path, start_ms, 30000)
            try:
                audio = whisperx.load_audio(audio_segment_file_path)
            finally:
                # The temp file is only needed to load the samples; remove it
                # even if loading fails.
                audio_segment_file_path.unlink()

            segment = log_mel_spectrogram(audio[: N_SAMPLES],
                                          n_mels=n_mels,
                                          padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
            encoder_output = model.model.encode(segment)
            results = model.model.model.detect_language(encoder_output)
            language_token, language_probability = results[0][0]
            # Strip the two delimiter characters on each side of the token
            # (Whisper language tokens look like "<|en|>").
            language = language_token[2:-2]

            print(f"Iteration {iteration} - Detected language: {language} ({language_probability:.2f})")

            detected_language = {
                "language": language,
                "probability": language_probability,
                "iterations": iteration
            }
            # Strict '>' keeps the earliest window on ties.
            if best is None or detected_language["probability"] > best["probability"]:
                best = detected_language

            if language_probability >= language_detection_min_prob or iteration >= last_iteration:
                return best

            iteration += 1
    finally:
        # Release the model before collecting so its memory can be reclaimed.
        del model
        gc.collect()
        torch.cuda.empty_cache()
def extract_audio_segment(input_file_path, start_time_ms, duration_ms):
    """Write ``duration_ms`` of audio starting at ``start_time_ms`` to a temp file.

    The segment is always exported as WAV, regardless of the input container.
    Deriving the output format from the input suffix fails for inputs without
    an extension (empty format name) and for suffixes that are not ffmpeg
    output format names (e.g. ``.m4a``). It also re-encodes lossy formats for
    no benefit.

    Args:
        input_file_path: path (str or path-like) of the source audio file.
        start_time_ms: start of the segment in milliseconds.
        duration_ms: length of the segment in milliseconds.

    Returns:
        Path of the temporary WAV file. The caller owns it and must delete it.
    """
    audio = AudioSegment.from_file(input_file_path)
    extracted_segment = audio[start_time_ms:start_time_ms + duration_ms]

    # delete=False: the file must outlive this handle so it can be exported to
    # and later read by the caller.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_file_path = Path(temp_file.name)

    try:
        # export() returns the still-open output file object; close it rather
        # than leaving the handle to the garbage collector.
        extracted_segment.export(temp_file_path, format="wav").close()
    except Exception:
        # Don't leak the temp file when the export fails.
        temp_file_path.unlink()
        raise

    return temp_file_path
def distribute_segments_equally(total_duration, segments_duration, iterations):
    """Return ``iterations`` evenly spaced start offsets for fixed-length windows.

    The first window starts at 0. When there is more than one window, the last
    one is pinned so it ends exactly at ``total_duration``. The windows in
    between are spaced by floor division of the remaining room. Offsets use
    the same unit as the inputs (milliseconds at the call site).

    If a window is longer than the whole audio, the available room is clamped
    to 0 so that no offset goes negative (a negative start is not a position
    inside the audio).

    Args:
        total_duration: length of the audio.
        segments_duration: length of each window.
        iterations: number of windows wanted.

    Returns:
        list of start offsets; empty when ``iterations`` <= 0.
    """
    available_duration = max(0, total_duration - segments_duration)

    if iterations <= 1:
        # Zero windows -> [], a single window -> [0].
        return [0] * max(iterations, 0)

    spacing = available_duration // (iterations - 1)
    start_times = [i * spacing for i in range(iterations - 1)]
    # Pin the last window to the end of the audio so floor-division rounding
    # never leaves the tail unprobed.
    start_times.append(available_duration)
    return start_times
def align(audio, result, debug):
    """Align transcription segments to obtain word-level timestamps.

    Loads the alignment model for ``result["language"]``, runs
    ``whisperx.align`` over ``result["segments"]`` (without character-level
    alignments) and frees the alignment model afterwards.

    Args:
        audio: audio samples as returned by ``whisperx.load_audio``.
        result: transcription result with ``language`` and ``segments`` keys.
        debug: when True, print the time spent aligning.

    Returns:
        The result returned by ``whisperx.align``.
    """
    start_time = time.time_ns() / 1e6

    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device,
                            return_char_alignments=False)

    if debug:
        elapsed_time = time.time_ns() / 1e6 - start_time
        print(f"Duration to align output: {elapsed_time:.2f} ms")

    # Drop the model reference *before* collecting; deleting it after
    # empty_cache() (as before) left its memory in use until the next cleanup.
    del model_a
    gc.collect()
    torch.cuda.empty_cache()

    return result