Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add caching of ranks for computing RSA with Spearman #34

Merged
merged 6 commits into from
Oct 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions mne_rsa/rsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,26 +161,38 @@ def rsa_gen(rdm_data_gen, rdm_model, metric="spearman", ignore_nan=False):
masks = [~np.isnan(rdm) for rdm in rdm_model]
else:
masks = [slice(None)] * len(rdm_model)
# Precompute ranks for Spearman
if metric == "spearman":
rdm_model = [stats.rankdata(rdm) for rdm in rdm_model]

for rdm_data in rdm_data_gen:
rdm_data = _ensure_condensed(rdm_data, "rdm_data")
if ignore_nan:
data_mask = ~np.isnan(rdm_data)
masks = [m & data_mask for m in masks]
rsa_vals = _rsa_single_rdm(rdm_data, rdm_model, metric, masks)
rsa_vals = _rsa_single_rdm(rdm_data, rdm_model, metric, masks, ignore_nan)
if return_array:
yield np.asarray(rsa_vals)
else:
yield rsa_vals[0]


def _rsa_single_rdm(rdm_data, rdm_model, metric, masks):
def _rsa_single_rdm(rdm_data, rdm_model, metric, masks, ignore_nan):
"""Compute RSA between a single data RDM and one or more model RDMs."""
if metric == "spearman":
rsa_vals = [
stats.spearmanr(rdm_data[mask], rdm_model_[mask])[0]
for rdm_model_, mask in zip(rdm_model, masks)
]
if not ignore_nan:
rdm_data = stats.rankdata(rdm_data)
rsa_vals = [
np.corrcoef(rdm_data, rdm_model_[mask])[0, 1]
for rdm_model_, mask in zip(rdm_model, masks)
]
else:
rsa_vals = [
np.corrcoef(
stats.rankdata(rdm_data[mask]), stats.rankdata(rdm_model_[mask])
)[0, 1]
for rdm_model_, mask in zip(rdm_model, masks)
]
elif metric == "pearson":
rsa_vals = [
stats.pearsonr(rdm_data[mask], rdm_model_[mask])[0]
gydis marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -423,6 +435,9 @@ def rsa_array(
masks = [~np.isnan(rdm) for rdm in rdm_model]
else:
masks = [slice(None)] * len(rdm_model)
# Precompute ranks for Spearman
if rsa_metric == "spearman":
rdm_model = [stats.rankdata(rdm) for rdm in rdm_model]

if verbose:
from tqdm import tqdm
Expand All @@ -449,7 +464,7 @@ def rsa_single_patch(patch):
patch_masks = [m & data_mask for m in masks]
else:
patch_masks = masks
return _rsa_single_rdm(rdm_data, rdm_model, rsa_metric, patch_masks)
return _rsa_single_rdm(rdm_data, rdm_model, rsa_metric, patch_masks, ignore_nan)

# Call RSA multiple times in parallel for each searchlight patch
data = Parallel(n_jobs=n_jobs)(
Expand Down