From 88f7575e521b8795cf70d613f07b81b831ec16b0 Mon Sep 17 00:00:00 2001 From: Eremin Egor Date: Wed, 9 Oct 2024 16:04:25 +0300 Subject: [PATCH 1/5] Add caching of ranks for computing RSA with Spearman --- mne_rsa/rsa.py | 37 +++++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/mne_rsa/rsa.py b/mne_rsa/rsa.py index a234e11..cdac128 100644 --- a/mne_rsa/rsa.py +++ b/mne_rsa/rsa.py @@ -109,7 +109,7 @@ def _partial_correlation(rdm_data, rdm_model, masks=None, type="pearson"): return -R_partial[0, 1:] -def rsa_gen(rdm_data_gen, rdm_model, metric="spearman", ignore_nan=False): +def rsa_gen(rdm_data_gen, rdm_model, metric="spearman", ignore_nan=False, cache_ranks=True): """Generate RSA values between data and model RDMs. Will yield RSA scores for each data RDM. @@ -156,30 +156,44 @@ def rsa_gen(rdm_data_gen, rdm_model, metric="spearman", ignore_nan=False): return_array = False rdm_model = [_ensure_condensed(rdm_model, "rdm_model")] + if ignore_nan: + if cache_ranks and (rsa_metric == "spearman"): + raise ValueError("ignore_nan and cache_ranks is not yet supported together") masks = [~np.isnan(rdm) for rdm in rdm_model] else: masks = [slice(None)] * len(rdm_model) + if cache_ranks and (rsa_metric == "spearman"): + # Precompute ranks for Spearman + rdm_model = [stats.rankdata(rdm) for rdm in rdm_model] + for rdm_data in rdm_data_gen: rdm_data = _ensure_condensed(rdm_data, "rdm_data") if ignore_nan: data_mask = ~np.isnan(rdm_data) masks = [m & data_mask for m in masks] - rsa_vals = _rsa_single_rdm(rdm_data, rdm_model, metric, masks) + rsa_vals = _rsa_single_rdm(rdm_data, rdm_model, metric, masks, cache_ranks) if return_array: yield np.asarray(rsa_vals) else: yield rsa_vals[0] -def _rsa_single_rdm(rdm_data, rdm_model, metric, masks): +def _rsa_single_rdm(rdm_data, rdm_model, metric, masks, cache_ranks): """Compute RSA between a single data RDM and one or more model RDMs.""" if metric == "spearman": - rsa_vals = [ - 
stats.spearmanr(rdm_data[mask], rdm_model_[mask])[0] - for rdm_model_, mask in zip(rdm_model, masks) - ] + if cache_ranks: + rdm_data = stats.rankdata(rdm_data) + rsa_vals = [ + np.corrcoef(rdm_data, rdm_model_[mask])[0,1] + for rdm_model_, mask in zip(rdm_model, masks) + ] + else: + rsa_vals = [ + stats.spearmanr(rdm_data[mask], rdm_model_[mask])[0] + for rdm_model_, mask in zip(rdm_model, masks) + ] elif metric == "pearson": rsa_vals = [ stats.pearsonr(rdm_data[mask], rdm_model_[mask])[0] @@ -312,6 +326,7 @@ def rsa_array( y=None, n_folds=1, n_jobs=1, + cache_ranks=True, verbose=False, ): """Perform RSA on an array of data, possibly in a searchlight pattern. @@ -419,6 +434,8 @@ def rsa_array( rdm_model = [_ensure_condensed(rdm_model, "rdm_model")] if ignore_nan: + if cache_ranks and (rsa_metric == "spearman"): + raise ValueError("ignore_nan and cache_ranks is not yet supported together") masks = [~np.isnan(rdm) for rdm in rdm_model] else: masks = [slice(None)] * len(rdm_model) @@ -433,6 +450,10 @@ def rsa_array( except AttributeError: pass + if cache_ranks and (rsa_metric == "spearman"): + # Precompute ranks for Spearman + rdm_model = [stats.rankdata(rdm) for rdm in rdm_model] + def rsa_single_patch(patch): """Compute RSA for a single searchlight patch.""" if len(X) == 1: # Check number of folds @@ -448,7 +469,7 @@ def rsa_single_patch(patch): patch_masks = [m & data_mask for m in masks] else: patch_masks = masks - return _rsa_single_rdm(rdm_data, rdm_model, rsa_metric, patch_masks) + return _rsa_single_rdm(rdm_data, rdm_model, rsa_metric, patch_masks, cache_ranks=cache_ranks) # Call RSA multiple times in parallel for each searchlight patch data = Parallel(n_jobs=n_jobs)( From a796d541eec5f7eb6ac973fe6f0f2e160c1881de Mon Sep 17 00:00:00 2001 From: gydis Date: Thu, 10 Oct 2024 15:42:20 +0300 Subject: [PATCH 2/5] Update mne_rsa/rsa.py Co-authored-by: Marijn van Vliet --- mne_rsa/rsa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_rsa/rsa.py 
b/mne_rsa/rsa.py index cdac128..510155b 100644 --- a/mne_rsa/rsa.py +++ b/mne_rsa/rsa.py @@ -186,7 +186,7 @@ def _rsa_single_rdm(rdm_data, rdm_model, metric, masks, cache_ranks): if cache_ranks: rdm_data = stats.rankdata(rdm_data) rsa_vals = [ - np.corrcoef(rdm_data, rdm_model_[mask])[0,1] + np.corrcoef(rdm_data, rdm_model_[mask])[0, 1] for rdm_model_, mask in zip(rdm_model, masks) ] else: From 3f642d69153f2484ad18c0e430daa122d5e77c65 Mon Sep 17 00:00:00 2001 From: gydis Date: Thu, 10 Oct 2024 16:49:06 +0300 Subject: [PATCH 3/5] Make rank caching to be the default if ignore_nan is False --- mne_rsa/rsa.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/mne_rsa/rsa.py b/mne_rsa/rsa.py index 510155b..92675d4 100644 --- a/mne_rsa/rsa.py +++ b/mne_rsa/rsa.py @@ -109,7 +109,7 @@ def _partial_correlation(rdm_data, rdm_model, masks=None, type="pearson"): return -R_partial[0, 1:] -def rsa_gen(rdm_data_gen, rdm_model, metric="spearman", ignore_nan=False, cache_ranks=True): +def rsa_gen(rdm_data_gen, rdm_model, metric="spearman", ignore_nan=False): """Generate RSA values between data and model RDMs. Will yield RSA scores for each data RDM. 
@@ -158,32 +158,29 @@ def rsa_gen(rdm_data_gen, rdm_model, metric="spearman", ignore_nan=False, cache_ if ignore_nan: - if cache_ranks and (rsa_metric == "spearman"): - raise ValueError("ignore_nan and cache_ranks is not yet supported together") masks = [~np.isnan(rdm) for rdm in rdm_model] else: masks = [slice(None)] * len(rdm_model) - - if cache_ranks and (rsa_metric == "spearman"): # Precompute ranks for Spearman - rdm_model = [stats.rankdata(rdm) for rdm in rdm_model] + if metric == "spearman": + rdm_model = [stats.rankdata(rdm) for rdm in rdm_model] for rdm_data in rdm_data_gen: rdm_data = _ensure_condensed(rdm_data, "rdm_data") if ignore_nan: data_mask = ~np.isnan(rdm_data) masks = [m & data_mask for m in masks] - rsa_vals = _rsa_single_rdm(rdm_data, rdm_model, metric, masks, cache_ranks) + rsa_vals = _rsa_single_rdm(rdm_data, rdm_model, metric, masks, ignore_nan) if return_array: yield np.asarray(rsa_vals) else: yield rsa_vals[0] -def _rsa_single_rdm(rdm_data, rdm_model, metric, masks, cache_ranks): +def _rsa_single_rdm(rdm_data, rdm_model, metric, masks, ignore_nan): """Compute RSA between a single data RDM and one or more model RDMs.""" if metric == "spearman": - if cache_ranks: + if not ignore_nan: rdm_data = stats.rankdata(rdm_data) rsa_vals = [ np.corrcoef(rdm_data, rdm_model_[mask])[0, 1] @@ -326,7 +323,6 @@ def rsa_array( y=None, n_folds=1, n_jobs=1, - cache_ranks=True, verbose=False, ): """Perform RSA on an array of data, possibly in a searchlight pattern. 
@@ -434,11 +430,12 @@ def rsa_array( rdm_model = [_ensure_condensed(rdm_model, "rdm_model")] if ignore_nan: - if cache_ranks and (rsa_metric == "spearman"): - raise ValueError("ignore_nan and cache_ranks is not yet supported together") masks = [~np.isnan(rdm) for rdm in rdm_model] else: masks = [slice(None)] * len(rdm_model) + # Precompute ranks for Spearman + if rsa_metric == "spearman": + rdm_model = [stats.rankdata(rdm) for rdm in rdm_model] if verbose: from tqdm import tqdm @@ -450,10 +447,6 @@ def rsa_array( except AttributeError: pass - if cache_ranks and (rsa_metric == "spearman"): - # Precompute ranks for Spearman - rdm_model = [stats.rankdata(rdm) for rdm in rdm_model] - def rsa_single_patch(patch): """Compute RSA for a single searchlight patch.""" if len(X) == 1: # Check number of folds @@ -469,7 +462,7 @@ def rsa_single_patch(patch): patch_masks = [m & data_mask for m in masks] else: patch_masks = masks - return _rsa_single_rdm(rdm_data, rdm_model, rsa_metric, patch_masks, cache_ranks=cache_ranks) + return _rsa_single_rdm(rdm_data, rdm_model, rsa_metric, patch_masks, ignore_nan) # Call RSA multiple times in parallel for each searchlight patch data = Parallel(n_jobs=n_jobs)( From e6ef68630945dd7092b7d54e3dc8ee44aa737bc2 Mon Sep 17 00:00:00 2001 From: gydis Date: Thu, 10 Oct 2024 16:51:35 +0300 Subject: [PATCH 4/5] Replace scipy spearman by rankdata and numpy.corrcoef --- mne_rsa/rsa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_rsa/rsa.py b/mne_rsa/rsa.py index 92675d4..9c5ddab 100644 --- a/mne_rsa/rsa.py +++ b/mne_rsa/rsa.py @@ -188,7 +188,7 @@ def _rsa_single_rdm(rdm_data, rdm_model, metric, masks, ignore_nan): ] else: rsa_vals = [ - stats.spearmanr(rdm_data[mask], rdm_model_[mask])[0] + np.corrcoef(stats.rankdata(rdm_data), stats.rankdata(rdm_model_[mask]))[0, 1] for rdm_model_, mask in zip(rdm_model, masks) ] elif metric == "pearson": From 69fe23d2984fe605e70551c761b0b668257c7cea Mon Sep 17 00:00:00 2001 From: Marijn van 
Vliet Date: Tue, 15 Oct 2024 08:47:38 +0300 Subject: [PATCH 5/5] Fix ignore_nan --- mne_rsa/rsa.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mne_rsa/rsa.py b/mne_rsa/rsa.py index 9c5ddab..e9ded25 100644 --- a/mne_rsa/rsa.py +++ b/mne_rsa/rsa.py @@ -156,7 +156,6 @@ def rsa_gen(rdm_data_gen, rdm_model, metric="spearman", ignore_nan=False): return_array = False rdm_model = [_ensure_condensed(rdm_model, "rdm_model")] - if ignore_nan: masks = [~np.isnan(rdm) for rdm in rdm_model] else: @@ -188,7 +187,9 @@ def _rsa_single_rdm(rdm_data, rdm_model, metric, masks, ignore_nan): ] else: rsa_vals = [ - np.corrcoef(stats.rankdata(rdm_data), stats.rankdata(rdm_model_[mask]))[0, 1] + np.corrcoef( + stats.rankdata(rdm_data[mask]), stats.rankdata(rdm_model_[mask]) + )[0, 1] for rdm_model_, mask in zip(rdm_model, masks) ] elif metric == "pearson":