diff --git a/rastermap/rastermap.py b/rastermap/rastermap.py index 9deb25f..e7a4b12 100644 --- a/rastermap/rastermap.py +++ b/rastermap/rastermap.py @@ -339,8 +339,8 @@ def fit(self, data=None, Usv=None, Vsv=None, U_nodes=None, itrain=None, self.sv = np.nansum((self.Usv**2), axis=0)**0.5 if not hasattr(self, "Vsv"): if Vsv is None: - U = self.Usv.copy() / self.sv - self.Vsv = X.T @ U + U = self.Usv[igood].copy() / self.sv + self.Vsv = X[igood].T @ U elif Vsv is not None: self.Vsv = Vsv @@ -491,7 +491,7 @@ def fit(self, data=None, Usv=None, Vsv=None, U_nodes=None, itrain=None, if (bin_size==0 or n_samples < bin_size or (bin_size == 50 and n_samples < 1000)): bin_size = max(1, n_samples // 500) - self.X_embedding = zscore(bin1d(X[igood][self.isort], bin_size, axis=0), axis=1) + self.X_embedding = zscore(bin1d(X[self.isort], bin_size, axis=0), axis=1) rmap_logger.info(f"rastermap complete, time {time.time() - t0:0.2f}sec")