Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions nemo/collections/asr/parts/utils/offline_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,13 @@ def kmeans_plusplus_torch(

centers = torch.zeros(n_clusters, n_features, dtype=X.dtype)
center_id = torch.randint(0, n_samples, (1,)).long()
indices = torch.full([n_clusters,], -1, dtype=torch.int)
indices = torch.full(
[
n_clusters,
],
-1,
dtype=torch.int,
)

centers[0] = X[center_id].squeeze(0)
indices[0] = center_id.squeeze(0)
Expand Down Expand Up @@ -511,7 +517,7 @@ def getMultiScaleCosAffinityMatrix(

Returns:
fused_sim_d (Tensor):
An affinity matrix that is obtained by calculating the weighted sum of
An affinity matrix that is obtained by calculating the weighted sum of
the multiple affinity matrices from the different scales.
"""
multiscale_weights = torch.squeeze(multiscale_weights, dim=0).to(device)
Expand Down Expand Up @@ -550,7 +556,18 @@ def eigDecompose(laplacian: torch.Tensor, cuda: bool, device: torch.device) -> T
laplacian = laplacian.float().to(device)
else:
laplacian = laplacian.float().to(torch.device('cpu'))
lambdas, diffusion_map = eigh(laplacian)

# The next line crashed sometimes during diatization
# Error: "linalg.eigh: Argument 8 has illegal value."
# This happens with torch 2.6 but not 2.3
# lambdas, diffusion_map = eigh(laplacian)

# The next fix ensure square, hermitian/symmetric inputs with correct layout
lambdas, diffusion_map = torch.linalg.eigh(
laplacian.to(torch.float64).clone().contiguous(), # sane dtype & layout
UPLO="L", # tell the backend which triangle is valid
)

return lambdas, diffusion_map


Expand Down Expand Up @@ -986,8 +1003,12 @@ def forward(self) -> Tuple[torch.Tensor, torch.Tensor]:
est_spk_n_dict: Dict[int, torch.Tensor] = {}
self.p_value_list = self.getPvalueList()
p_volume = self.p_value_list.shape[0]
eig_ratio_list = torch.zeros(p_volume,)
est_num_of_spk_list = torch.zeros(p_volume,)
eig_ratio_list = torch.zeros(
p_volume,
)
est_num_of_spk_list = torch.zeros(
p_volume,
)

if self.parallelism:
futures: List[torch.jit.Future[torch.Tensor]] = []
Expand Down Expand Up @@ -1176,10 +1197,10 @@ def forward_unit_infer(
kmeans_random_trials: int = 1,
) -> torch.LongTensor:
"""
This function takes a cosine similarity matrix `mat` and returns the speaker labels for the segments
in the given input embeddings.
Args:
This function takes a cosine similarity matrix `mat` and returns the speaker labels for the segments
in the given input embeddings.

Args:
mat (Tensor):
Cosine similarity matrix (affinity matrix) calculated from the provided speaker embeddings.
oracle_num_speakers (int):
Expand All @@ -1202,8 +1223,8 @@ def forward_unit_infer(
This value should be optimized on a development set for best results.
By default, it is set to -1.0, and the function performs NME-analysis to estimate the threshold.
kmeans_random_trials (int):
The number of random trials for initializing k-means clustering. More trials can result in more stable clustering. The default is 1.
The number of random trials for initializing k-means clustering. More trials can result in more stable clustering. The default is 1.

Returns:
Y (LongTensor):
Speaker labels (clustering output) in integer format for the segments in the given input embeddings.
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements_asr.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ lhotse>=1.31.1
# Align with upstream PyTorch requirements
librosa>=0.10.1
marshmallow
megatron-core
optuna
packaging
pyannote.core
Expand Down