
Commit 4e2ac4e

torch2.0, remove compile for now, round times to 3 decimals
1 parent d2116b9

6 files changed: +40 −34 lines

README.md

Lines changed: 6 additions & 6 deletions
```diff
@@ -61,23 +61,23 @@
 
 
 <h2 align="left" id="setup">Setup ⚙️</h2>
-Tested for PyTorch 0.11, Python 3.8 (use other versions at your own risk!)
+Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)
 
 GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
 
 
-### 1. Create Python3.8 environment
+### 1. Create Python3.10 environment
 
-`conda create --name whisperx python=3.8`
+`conda create --name whisperx python=3.10`
 
 `conda activate whisperx`
 
 
-### 2. Install PyTorch 0.11.0, e.g. for Linux and Windows:
+### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:
 
-`pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113`
+`pip3 install torch torchvision torchaudio`
 
-See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4)
+See other methods [here.](https://pytorch.org/get-started/locally/)
 
 ### 3. Install this repo
```
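After following the updated steps, a quick check (a minimal sketch, assuming the `pip3` install above succeeded) confirms the environment matches the new requirements:

```python
import torch

# Expect a 2.x version string, e.g. "2.0.0+cu117" for the CUDA 11.7 wheel
print(torch.__version__)

# Should be True on a machine with the cuBLAS/cuDNN libraries installed
print(torch.cuda.is_available())
```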

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,7 +6,7 @@
 setup(
     name="whisperx",
     py_modules=["whisperx"],
-    version="3.0.0",
+    version="3.0.2",
     description="Time-Accurate Automatic Speech Recognition using Whisper.",
     readme="README.md",
     python_requires=">=3.8",
```

whisperx/alignment.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -268,8 +268,8 @@ def align(
             start, end, score = None, None, None
             if cdx in clean_cdx:
                 char_seg = char_segments[clean_cdx.index(cdx)]
-                start = char_seg.start * ratio + t1
-                end = char_seg.end * ratio + t1
+                start = round(char_seg.start * ratio + t1, 3)
+                end = round(char_seg.end * ratio + t1, 3)
                 score = char_seg.score
 
             char_segments_arr["char"].append(char)
```
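The rounding change caps output timestamps at millisecond precision instead of leaking full float precision into the results. A standalone illustration with hypothetical values (`ratio` and `t1` stand in for the frame-to-seconds factor and segment offset used in `align`):

```python
# Hypothetical values: ratio converts alignment frames to seconds, t1 is the segment start
ratio, t1 = 0.02004, 3.5
frame_start = 127.0  # a char_seg.start value from the alignment

start = frame_start * ratio + t1
print(start)            # ≈ 6.04508 (before: raw float carried into the output)
print(round(start, 3))  # 6.045    (after: rounded to 3 decimals, i.e. milliseconds)
```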

whisperx/asr.py

Lines changed: 1 addition & 4 deletions
```diff
@@ -181,9 +181,6 @@ def _sanitize_parameters(self, **kwargs):
 
     def preprocess(self, audio):
         audio = audio['inputs']
-        if isinstance(audio, np.ndarray):
-            audio = torch.from_numpy(audio)
-
         features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
         return {'inputs': features}
 
@@ -256,7 +253,7 @@ def data(audio, segments):
     def detect_language(self, audio: np.ndarray):
         if audio.shape[0] < N_SAMPLES:
             print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
-        segment = log_mel_spectrogram(torch.from_numpy(audio[:N_SAMPLES]),
+        segment = log_mel_spectrogram(audio[: N_SAMPLES],
                                       padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
         encoder_output = self.model.encode(segment)
         results = self.model.model.detect_language(encoder_output)
```
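Both call sites can now hand NumPy audio straight to `log_mel_spectrogram`, since the tensor conversion moved into `whisperx/audio.py` (next file). A minimal sketch of the new call path, assuming `N_SAMPLES` is importable from `whisperx.audio` as in this repo's layout:

```python
import numpy as np
from whisperx.audio import N_SAMPLES, log_mel_spectrogram

audio = np.zeros(N_SAMPLES, dtype=np.float32)  # 30 s of silence at 16 kHz

# No torch.from_numpy() needed at the call site anymore
features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
print(features.shape)  # torch.Size([80, 3000]): 80 mel bins, one frame per 10 ms
```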

whisperx/audio.py

Lines changed: 29 additions & 15 deletions
```diff
@@ -22,12 +22,6 @@
 FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
 
-with np.load(
-    os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
-) as f:
-    MEL_FILTERS = torch.from_numpy(f[f"mel_{80}"])
-
-
 
 def load_audio(file: str, sr: int = SAMPLE_RATE):
     """
@@ -85,9 +79,27 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
     return array
 
 
-@torch.compile(fullgraph=True)
+@lru_cache(maxsize=None)
+def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
+    """
+    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+    Allows decoupling librosa dependency; saved using:
+
+        np.savez_compressed(
+            "mel_filters.npz",
+            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+        )
+    """
+    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+    with np.load(
+        os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+    ) as f:
+        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+
+
 def log_mel_spectrogram(
-    audio: torch.Tensor,
+    audio: Union[str, np.ndarray, torch.Tensor],
+    n_mels: int = N_MELS,
     padding: int = 0,
     device: Optional[Union[str, torch.device]] = None,
 ):
@@ -96,7 +108,7 @@ def log_mel_spectrogram(
 
     Parameters
     ----------
-    audio: torch.Tensor, shape = (*)
+    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
         The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
 
     n_mels: int
@@ -113,19 +125,21 @@ def log_mel_spectrogram(
     torch.Tensor, shape = (80, n_frames)
         A Tensor that contains the Mel spectrogram
     """
-    global MEL_FILTERS
+    if not torch.is_tensor(audio):
+        if isinstance(audio, str):
+            audio = load_audio(audio)
+        audio = torch.from_numpy(audio)
 
     if device is not None:
         audio = audio.to(device)
     if padding > 0:
         audio = F.pad(audio, (0, padding))
     window = torch.hann_window(N_FFT).to(audio.device)
-    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=False)
-    # Square the real and imaginary components and sum them together, similar to torch.abs() on complex tensors
-    magnitudes = (stft[:, :-1, :] ** 2).sum(dim=-1)
+    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
+    magnitudes = stft[..., :-1].abs() ** 2
 
-    MEL_FILTERS = MEL_FILTERS.to(audio.device)
-    mel_spec = MEL_FILTERS @ magnitudes
+    filters = mel_filters(audio.device, n_mels)
+    mel_spec = filters @ magnitudes
 
     log_spec = torch.clamp(mel_spec, min=1e-10).log10()
     log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
```
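The module-level `MEL_FILTERS` global (and the `@torch.compile` decorator, removed for now per the commit message) gives way to an `lru_cache`d loader keyed on device, so the filterbank is loaded once per device and then served from the cache. The caching pattern in isolation, with a toy filterbank standing in for `mel_filters.npz`:

```python
from functools import lru_cache

import torch

@lru_cache(maxsize=None)
def toy_filters(device: str, n_mels: int = 80) -> torch.Tensor:
    # Stand-in for the npz load: the body runs once per (device, n_mels) pair;
    # later calls with the same arguments hit the cache instead.
    print(f"loading filters onto {device}")
    return torch.eye(n_mels, device=device)

toy_filters("cpu")  # prints "loading filters onto cpu"
toy_filters("cpu")  # cache hit: no print, no reload
```

Separately, `return_complex=True` plus `.abs() ** 2` computes the same power spectrum as the old sum of squared real and imaginary components, using the complex STFT API that newer PyTorch versions expect (`return_complex=False` is deprecated).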

whisperx/transcribe.py

Lines changed: 1 addition & 6 deletions
```diff
@@ -72,7 +72,6 @@ def cli():
 
     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
     # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
-    parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
     # fmt: on
 
     args = parser.parse_args().__dict__
@@ -86,10 +85,6 @@ def cli():
     # model_flush: bool = args.pop("model_flush")
     os.makedirs(output_dir, exist_ok=True)
 
-    tmp_dir: str = args.pop("tmp_dir")
-    if tmp_dir is not None:
-        os.makedirs(tmp_dir, exist_ok=True)
-
     align_model: str = args.pop("align_model")
     interpolate_method: str = args.pop("interpolate_method")
     no_align: bool = args.pop("no_align")
@@ -195,7 +190,7 @@ def cli():
         tmp_results = results
         print(">>Performing diarization...")
         results = []
-        diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
+        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
         for result, input_audio_path in tmp_results:
             diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
             results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
```
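Dropping the `device` argument leaves device selection to `DiarizationPipeline` itself. A condensed sketch of the resulting call sequence, with placeholder inputs (the import path and default device behavior are assumptions, not shown in this diff):

```python
from whisperx.diarize import DiarizationPipeline, assign_word_speakers  # import path assumed

hf_token = "hf_..."             # placeholder: a real Hugging Face token is required
input_audio_path = "audio.wav"  # placeholder audio file
result = {"segments": []}       # placeholder ASR + alignment output for one file

# After this commit: no device= argument; the pipeline picks its own device.
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
diarize_segments = diarize_model(input_audio_path, min_speakers=1, max_speakers=2)
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
```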
