Skip to content

Commit 9ab7de6

Browse files
David Ayllondayllon-balto
authored andcommitted
fix eigh input
1 parent 1469922 commit 9ab7de6

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

nemo/collections/asr/parts/utils/offline_clustering.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ def getMultiScaleCosAffinityMatrix(
511511
512512
Returns:
513513
fused_sim_d (Tensor):
514-
An affinity matrix that is obtained by calculating the weighted sum of
514+
An affinity matrix that is obtained by calculating the weighted sum of
515515
the multiple affinity matrices from the different scales.
516516
"""
517517
multiscale_weights = torch.squeeze(multiscale_weights, dim=0).to(device)
@@ -550,7 +550,18 @@ def eigDecompose(laplacian: torch.Tensor, cuda: bool, device: torch.device) -> T
550550
laplacian = laplacian.float().to(device)
551551
else:
552552
laplacian = laplacian.float().to(torch.device('cpu'))
553-
lambdas, diffusion_map = eigh(laplacian)
553+
554+
#The next line crashed sometimes during diatization
555+
#Error: "linalg.eigh: Argument 8 has illegal value."
556+
#This happens with torch 2.6 but not 2.3
557+
#lambdas, diffusion_map = eigh(laplacian)
558+
559+
#The next fix ensure square, hermitian/symmetric inputs with correct layout
560+
lambdas, diffusion_map = torch.linalg.eigh(
561+
laplacian.to(torch.float64).clone().contiguous(), # sane dtype & layout
562+
UPLO="L", # tell the backend which triangle is valid
563+
)
564+
554565
return lambdas, diffusion_map
555566

556567

@@ -1176,10 +1187,10 @@ def forward_unit_infer(
11761187
kmeans_random_trials: int = 1,
11771188
) -> torch.LongTensor:
11781189
"""
1179-
This function takes a cosine similarity matrix `mat` and returns the speaker labels for the segments
1180-
in the given input embeddings.
1181-
1182-
Args:
1190+
This function takes a cosine similarity matrix `mat` and returns the speaker labels for the segments
1191+
in the given input embeddings.
1192+
1193+
Args:
11831194
mat (Tensor):
11841195
Cosine similarity matrix (affinity matrix) calculated from the provided speaker embeddings.
11851196
oracle_num_speakers (int):
@@ -1202,8 +1213,8 @@ def forward_unit_infer(
12021213
This value should be optimized on a development set for best results.
12031214
By default, it is set to -1.0, and the function performs NME-analysis to estimate the threshold.
12041215
kmeans_random_trials (int):
1205-
The number of random trials for initializing k-means clustering. More trials can result in more stable clustering. The default is 1.
1206-
1216+
The number of random trials for initializing k-means clustering. More trials can result in more stable clustering. The default is 1.
1217+
12071218
Returns:
12081219
Y (LongTensor):
12091220
Speaker labels (clustering output) in integer format for the segments in the given input embeddings.

0 commit comments

Comments
 (0)