Skip to content

Commit

Permalink
Merge pull request #986 from bfs18/main
Browse files Browse the repository at this point in the history
support timestamp for numbers.
  • Loading branch information
m-bain authored Jan 14, 2025
2 parents 1027367 + ffbc736 commit 12604a4
Showing 1 changed file with 186 additions and 44 deletions.
230 changes: 186 additions & 44 deletions whisperx/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Forced Alignment with Whisper
C. Max Bain
"""
import math

from dataclasses import dataclass
from typing import Iterable, Optional, Union, List
Expand Down Expand Up @@ -171,10 +172,17 @@ def align(
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
else:
# add placeholder
clean_char.append('*')
clean_cdx.append(cdx)

clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
if any([c in model_dictionary.keys() for c in wrd.lower()]):
clean_wdx.append(wdx)
else:
# index for placeholder
clean_wdx.append(wdx)


Expand Down Expand Up @@ -222,7 +230,7 @@ def align(
continue

text_clean = "".join(segment_data[sdx]["clean_char"])
tokens = [model_dictionary[c] for c in text_clean]
tokens = [model_dictionary.get(c, -1) for c in text_clean]

f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
Expand Down Expand Up @@ -255,7 +263,8 @@ def align(
blank_id = code

trellis = get_trellis(emission, tokens, blank_id)
path = backtrack(trellis, emission, tokens, blank_id)
# path = backtrack(trellis, emission, tokens, blank_id)
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)

if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
Expand All @@ -264,7 +273,7 @@ def align(

char_segments = merge_repeats(path, text_clean)

duration = t2 -t1
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)

# assign timestamps to aligned characters
Expand Down Expand Up @@ -371,70 +380,203 @@ def align(
"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""


def get_trellis(emission, tokens, blank_id=0):
num_frame = emission.size(0)
num_tokens = len(tokens)

# Trellis has extra diemsions for both time axis and tokens.
# The extra dim for tokens represents <SoS> (start-of-sentence)
# The extra dim for time axis is for simplification of the code.
trellis = torch.empty((num_frame + 1, num_tokens + 1))
trellis[0, 0] = 0
trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
trellis[0, -num_tokens:] = -float("inf")
trellis[-num_tokens:, 0] = float("inf")
trellis = torch.zeros((num_frame, num_tokens))
trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
trellis[0, 1:] = -float("inf")
trellis[-num_tokens + 1:, 0] = float("inf")

for t in range(num_frame):
for t in range(num_frame - 1):
trellis[t + 1, 1:] = torch.maximum(
# Score for staying at the same token
trellis[t, 1:] + emission[t, blank_id],
# Score for changing to the next token
trellis[t, :-1] + emission[t, tokens],
# trellis[t, :-1] + emission[t, tokens[1:]],
trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
)
return trellis


def get_wildcard_emission(frame_emission, tokens, blank_id):
"""Processing token emission scores containing wildcards (vectorized version)
Args:
frame_emission: Emission probability vector for the current frame
tokens: List of token indices
blank_id: ID of the blank token
Returns:
tensor: Maximum probability score for each token position
"""
assert 0 <= blank_id < len(frame_emission)

# Convert tokens to a tensor if they are not already
tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens

# Create a mask to identify wildcard positions
wildcard_mask = (tokens == -1)

# Get scores for non-wildcard positions
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index

# Create a mask and compute the maximum value without modifying frame_emission
max_valid_score = frame_emission.clone() # Create a copy
max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token
max_valid_score = max_valid_score.max()

# Use where operation to combine results
result = torch.where(wildcard_mask, max_valid_score, regular_scores)

return result


@dataclass
class Point:
token_index: int
time_index: int
score: float


def backtrack(trellis, emission, tokens, blank_id=0):
# Note:
# j and t are indices for trellis, which has extra dimensions
# for time and tokens at the beginning.
# When referring to time frame index `T` in trellis,
# the corresponding index in emission is `T-1`.
# Similarly, when referring to token index `J` in trellis,
# the corresponding index in transcript is `J-1`.
j = trellis.size(1) - 1
t_start = torch.argmax(trellis[:, j]).item()

path = []
for t in range(t_start, 0, -1):
t, j = trellis.size(0) - 1, trellis.size(1) - 1

path = [Point(j, t, emission[t, blank_id].exp().item())]
while j > 0:
# Should not happen but just in case
assert t > 0

# 1. Figure out if the current position was stay or change
# Note (again):
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
# Score for token staying the same from time frame J-1 to T.
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
# Score for token changing from C-1 at T-1 to J at T.
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

# 2. Store the path with frame-wise probability.
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
# Return token index and time index in non-trellis coordinate.
path.append(Point(j - 1, t - 1, prob))

# 3. Update the token
# Frame-wise score of stay vs change
p_stay = emission[t - 1, blank_id]
# p_change = emission[t - 1, tokens[j]]
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]

# Context-aware score for stay vs change
stayed = trellis[t - 1, j] + p_stay
changed = trellis[t - 1, j - 1] + p_change

# Update position
t -= 1
if changed > stayed:
j -= 1
if j == 0:
break
else:
# failed
return None

# Store the path with frame-wise probability.
prob = (p_change if changed > stayed else p_stay).exp().item()
path.append(Point(j, t, prob))

# Now j == 0, which means, it reached the SoS.
# Fill up the rest for the sake of visualization
while t > 0:
prob = emission[t - 1, blank_id].exp().item()
path.append(Point(j, t - 1, prob))
t -= 1

return path[::-1]



@dataclass
class Path:
points: List[Point]
score: float


@dataclass
class BeamState:
"""State in beam search."""
token_index: int # Current token position
time_index: int # Current time step
score: float # Cumulative score
path: List[Point] # Path history


def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
"""Standard CTC beam search backtracking implementation.
Args:
trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
and N is the number of tokens (including the blank token).
emission (torch.Tensor): The emission probabilities of shape (T, N).
tokens (List[int]): List of token indices (excluding the blank token).
blank_id (int, optional): The ID of the blank token. Defaults to 0.
beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.
Returns:
List[Point]: the best path
"""
T, J = trellis.size(0) - 1, trellis.size(1) - 1

init_state = BeamState(
token_index=J,
time_index=T,
score=trellis[T, J],
path=[Point(J, T, emission[T, blank_id].exp().item())]
)

beams = [init_state]

while beams and beams[0].token_index > 0:
next_beams = []

for beam in beams:
t, j = beam.time_index, beam.token_index

if t <= 0:
continue

p_stay = emission[t - 1, blank_id]
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]

stay_score = trellis[t - 1, j]
change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')

# Stay
if not math.isinf(stay_score):
new_path = beam.path.copy()
new_path.append(Point(j, t - 1, p_stay.exp().item()))
next_beams.append(BeamState(
token_index=j,
time_index=t - 1,
score=stay_score,
path=new_path
))

# Change
if j > 0 and not math.isinf(change_score):
new_path = beam.path.copy()
new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
next_beams.append(BeamState(
token_index=j - 1,
time_index=t - 1,
score=change_score,
path=new_path
))

# sort by score
beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]

if not beams:
break

if not beams:
return None

best_beam = beams[0]
t = best_beam.time_index
j = best_beam.token_index
while t > 0:
prob = emission[t - 1, blank_id].exp().item()
best_beam.path.append(Point(j, t - 1, prob))
t -= 1

return best_beam.path[::-1]


# Merge the labels
@dataclass
class Segment:
Expand Down

0 comments on commit 12604a4

Please sign in to comment.