Skip to content

Commit 352a347

Browse files
committed
Apply isort and black reformatting
Signed-off-by: dayllon-balto <[email protected]>
1 parent 1784c8d commit 352a347

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

nemo/collections/asr/parts/utils/offline_clustering.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,13 @@ def kmeans_plusplus_torch(
150150

151151
centers = torch.zeros(n_clusters, n_features, dtype=X.dtype)
152152
center_id = torch.randint(0, n_samples, (1,)).long()
153-
indices = torch.full([n_clusters,], -1, dtype=torch.int)
153+
indices = torch.full(
154+
[
155+
n_clusters,
156+
],
157+
-1,
158+
dtype=torch.int,
159+
)
154160

155161
centers[0] = X[center_id].squeeze(0)
156162
indices[0] = center_id.squeeze(0)
@@ -551,12 +557,12 @@ def eigDecompose(laplacian: torch.Tensor, cuda: bool, device: torch.device) -> T
551557
else:
552558
laplacian = laplacian.float().to(torch.device('cpu'))
553559

554-
#The next line crashed sometimes during diatization
555-
#Error: "linalg.eigh: Argument 8 has illegal value."
556-
#This happens with torch 2.6 but not 2.3
557-
#lambdas, diffusion_map = eigh(laplacian)
560+
# The next line crashed sometimes during diatization
561+
# Error: "linalg.eigh: Argument 8 has illegal value."
562+
# This happens with torch 2.6 but not 2.3
563+
# lambdas, diffusion_map = eigh(laplacian)
558564

559-
#The next fix ensure square, hermitian/symmetric inputs with correct layout
565+
# The next fix ensure square, hermitian/symmetric inputs with correct layout
560566
lambdas, diffusion_map = torch.linalg.eigh(
561567
laplacian.to(torch.float64).clone().contiguous(), # sane dtype & layout
562568
UPLO="L", # tell the backend which triangle is valid
@@ -997,8 +1003,12 @@ def forward(self) -> Tuple[torch.Tensor, torch.Tensor]:
9971003
est_spk_n_dict: Dict[int, torch.Tensor] = {}
9981004
self.p_value_list = self.getPvalueList()
9991005
p_volume = self.p_value_list.shape[0]
1000-
eig_ratio_list = torch.zeros(p_volume,)
1001-
est_num_of_spk_list = torch.zeros(p_volume,)
1006+
eig_ratio_list = torch.zeros(
1007+
p_volume,
1008+
)
1009+
est_num_of_spk_list = torch.zeros(
1010+
p_volume,
1011+
)
10021012

10031013
if self.parallelism:
10041014
futures: List[torch.jit.Future[torch.Tensor]] = []

0 commit comments

Comments
 (0)