Skip to content

Commit

Permalink
Merge branch 'main' of github.com:wmvanvliet/mne-rsa
Browse files Browse the repository at this point in the history
  • Loading branch information
wmvanvliet committed Feb 12, 2025
2 parents 0334c18 + ed7cfd4 commit 09949e7
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions mne_rsa/rsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,26 +161,38 @@ def rsa_gen(rdm_data_gen, rdm_model, metric="spearman", ignore_nan=False):
masks = [~np.isnan(rdm) for rdm in rdm_model]
else:
masks = [slice(None)] * len(rdm_model)
# Precompute ranks for Spearman
if metric == "spearman":
rdm_model = [stats.rankdata(rdm) for rdm in rdm_model]

for rdm_data in rdm_data_gen:
rdm_data = _ensure_condensed(rdm_data, "rdm_data")
if ignore_nan:
data_mask = ~np.isnan(rdm_data)
masks = [m & data_mask for m in masks]
rsa_vals = _rsa_single_rdm(rdm_data, rdm_model, metric, masks)
rsa_vals = _rsa_single_rdm(rdm_data, rdm_model, metric, masks, ignore_nan)
if return_array:
yield np.asarray(rsa_vals)
else:
yield rsa_vals[0]


def _rsa_single_rdm(rdm_data, rdm_model, metric, masks):
def _rsa_single_rdm(rdm_data, rdm_model, metric, masks, ignore_nan):
"""Compute RSA between a single data RDM and one or more model RDMs."""
if metric == "spearman":
rsa_vals = [
stats.spearmanr(rdm_data[mask], rdm_model_[mask])[0]
for rdm_model_, mask in zip(rdm_model, masks)
]
if not ignore_nan:
rdm_data = stats.rankdata(rdm_data)
rsa_vals = [
np.corrcoef(rdm_data, rdm_model_[mask])[0, 1]
for rdm_model_, mask in zip(rdm_model, masks)
]
else:
rsa_vals = [
np.corrcoef(
stats.rankdata(rdm_data[mask]), stats.rankdata(rdm_model_[mask])
)[0, 1]
for rdm_model_, mask in zip(rdm_model, masks)
]
elif metric == "pearson":
rsa_vals = [
stats.pearsonr(rdm_data[mask], rdm_model_[mask])[0]
Expand Down Expand Up @@ -423,6 +435,9 @@ def rsa_array(
masks = [~np.isnan(rdm) for rdm in rdm_model]
else:
masks = [slice(None)] * len(rdm_model)
# Precompute ranks for Spearman
if rsa_metric == "spearman":
rdm_model = [stats.rankdata(rdm) for rdm in rdm_model]

if verbose:
from tqdm import tqdm
Expand All @@ -449,7 +464,7 @@ def rsa_single_patch(patch):
patch_masks = [m & data_mask for m in masks]
else:
patch_masks = masks
return _rsa_single_rdm(rdm_data, rdm_model, rsa_metric, patch_masks)
return _rsa_single_rdm(rdm_data, rdm_model, rsa_metric, patch_masks, ignore_nan)

# Call RSA multiple times in parallel for each searchlight patch
data = Parallel(n_jobs=n_jobs)(
Expand Down

0 comments on commit 09949e7

Please sign in to comment.