Skip to content

Commit 22ca801

Browse files
Merge pull request #5 from ContextLab/rev-1
PNAS round 2
2 parents b74d368 + c5bf957 commit 22ca801

File tree

127 files changed

+5031
-1301
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

127 files changed

+5031
-1301
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,4 @@ data/pieman2_htfa.pkl
171171
data/neurosynth/*.xlsx
172172
data/neurosynth/Topic_Labels.csv
173173
paper/figs/source/neurosynth_topics.aux
174+
data/scratch/decoding_results_shuffled_*.pkl

code/notebooks/decoding_and_compression.ipynb

Lines changed: 818 additions & 106 deletions
Large diffs are not rendered by default.

code/notebooks/helpers.py

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sklearn.decomposition import IncrementalPCA as PCA
1010
from tqdm import tqdm
1111
from scipy.spatial.distance import cdist
12+
import statsmodels.api as sm
1213

1314
import os
1415
import warnings
@@ -100,10 +101,6 @@ def fullfact(dims):
100101
vox_coords = fullfact(img.shape[0:3])[vmask, ::-1]-1
101102

102103
R = np.array(np.dot(vox_coords, S[0:3, 0:3])) + S[:3, 3]
103-
104-
# center on the MNI152 brain (hard code this in)
105-
#mni_center = np.array([0.55741881, -21.52140703, 9.83783098])
106-
#R = R - R.mean(axis=0) + mni_center
107104

108105
return {'Y': Y, 'R': R}
109106

@@ -208,13 +205,39 @@ def cross_validation(data, n_iter=10, fname=None, max_components=700):
208205
return results
209206

210207

211-
def ridge_plot(x, column='Number of components', fname=None, xlim=[-99, 700], hue='Condition', palette=[condition_colors[c] for c in conditions]):
208+
def ridge_plot(x, column='Number of components', fname=None, xlim=[-99, 700], hue='Condition', palette=[condition_colors[c] for c in conditions], scale_start=0.25, scale_height=0.1):
209+
def pdf_plot(x, ax=None, xlim=[0.0, 1.0], resolution=1000, **kwargs):
210+
if ax is None:
211+
ax = plt.gca()
212+
213+
if 'color' in kwargs:
214+
color = kwargs['color']
215+
kwargs.pop('color')
216+
else:
217+
color = 'k'
218+
219+
density = sm.nonparametric.KDEUnivariate(x)
220+
density.fit()
221+
222+
xs = np.linspace(xlim[0], xlim[1], resolution)
223+
ys = density.evaluate(xs)
224+
ys = ys / np.sum(ys)
225+
226+
ax.fill(xs, ys, color=color, **kwargs)
227+
return ax
212228

213229
sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
214230
g = sns.FacetGrid(x, row=hue, hue=hue, palette=palette, height=1, aspect=6)
215-
g.map(sns.kdeplot, column, bw_adjust=1, clip_on=True, fill=True, alpha=1, common_norm=True, linewidth=1.5)
231+
# g.map(sns.kdeplot, column, bw_adjust=1, clip_on=True, fill=True, alpha=1, common_norm=True, linewidth=1.5) # REPLACE WITH A NEW FUNCTION THAT NORMALIZES AREA TO 1
232+
g.map(pdf_plot, column, xlim=xlim, linewidth=1.5)
216233
g.refline(y=0, linewidth=1.5, linestyle='-', color=None, clip_on=False)
217234

235+
# plot a scale bar in the upper right
236+
if scale_height is not None:
237+
# compute the x position of the scale bar -- 98% of the way to the right
238+
x = xlim[0] + 0.98 * (xlim[1] - xlim[0])
239+
g.axes[0][0].plot([x, x], [scale_start, scale_start + scale_height], color='k', linewidth=1.5)
240+
218241
def label(x, color, label):
219242
ax = plt.gca()
220243
ax.text(0, 0.2, label.capitalize(), color=color, ha='left', va='center', transform=ax.transAxes)
@@ -231,7 +254,7 @@ def label(x, color, label):
231254

232255
ax = plt.gca()
233256
ax.set_xlim(xlim[0], xlim[1])
234-
ax.set_xlabel(column, fontsize=12)
257+
ax.set_xlabel(column, fontsize=12)
235258

236259
if fname is not None:
237260
g.savefig(os.path.join(figdir, fname + '.pdf'), bbox_inches='tight')
@@ -254,7 +277,7 @@ def get_data():
254277
return data
255278

256279

257-
def info_and_compressibility(d, target=0.05):
280+
def info_and_compressibility(d, target=None):
258281
def closest(x, target):
259282
dists = np.abs(x.values - target)
260283
dists[x.values < target] += 10 * np.max(dists)
@@ -264,19 +287,23 @@ def closest(x, target):
264287
for c in conditions:
265288
dc = d[c].astype(float).pivot(index='Iteration', columns='Number of components', values='Relative decoding accuracy')
266289
i = pd.DataFrame()
267-
i['Number of components'] = dc.apply(lambda x: closest(x, target), axis=1, raw=False)
290+
291+
if target is None:
292+
i['Number of components'] = dc.idxmax(axis=1).astype(int)
293+
else:
294+
i['Number of components'] = dc.apply(lambda x: closest(x, target), axis=1, raw=False)
268295
i['Relative decoding accuracy'] = dc.max(axis=1)
269296
i['Condition'] = c
270297
i['Iteration'] = dc.index.values.astype(int)
271298
df.append(i)
272299
return pd.concat(df, ignore_index=True, axis=0)
273300

274301

275-
def plot_info_and_compressibility_scatter(x, fname=None):
302+
def plot_info_and_compressibility_scatter(x, fname=None, target=None):
276303
fig = plt.figure(figsize=(4, 3))
277304
ax = plt.gca()
278305

279-
x = info_and_compressibility(x)
306+
x = info_and_compressibility(x, target=target)
280307
sns.scatterplot(x, x='Number of components', y='Relative decoding accuracy', hue='Condition', palette=[condition_colors[c] for c in conditions], legend=False, s=10, ax=ax)
281308
sns.scatterplot(x.groupby('Condition').mean().loc[conditions].reset_index(), x='Number of components', y='Relative decoding accuracy', hue='Condition', palette=[condition_colors[c] for c in conditions], legend=False, s=100, ax=ax)
282309

@@ -293,4 +320,36 @@ def plot_info_and_compressibility_scatter(x, fname=None):
293320
return fig
294321

295322
def rbf(R, center, width):
296-
return np.exp(-np.sum((R - center) ** 2, axis=1) / width)
323+
return np.exp(-np.sum((R - center) ** 2, axis=1) / width)
324+
325+
326+
def plot_accuracy(x, figdir=None, fname=None, conditions=['intact', 'paragraph', 'word', 'rest'], condition_colors=condition_colors, ylim=[-0.01, 0.35], xlim=[3, 700], ax=None):
327+
if figdir is not None and not os.path.exists(figdir):
328+
os.makedirs(figdir)
329+
330+
if ax is None:
331+
fig = plt.figure(figsize=(4, 3))
332+
ax = plt.gca()
333+
else:
334+
fig = plt.gcf()
335+
336+
for c in conditions:
337+
sns.lineplot(x[c], x='Number of components', y='Relative decoding accuracy', label=c.capitalize(), color=condition_colors[c], legend=False, ax=ax)
338+
339+
ax.set_xlabel('Number of components', fontsize=12)
340+
ax.set_ylabel('Relative decoding accuracy', fontsize=12)
341+
ax.set_ylim(ylim)
342+
ax.set_xlim(xlim)
343+
ax.spines[['right', 'top']].set_visible(False)
344+
345+
if fname is not None:
346+
fig.savefig(os.path.join(figdir, fname + '.pdf'), bbox_inches='tight')
347+
348+
return fig
349+
350+
351+
def pstring(pval):
352+
if pval < 0.001:
353+
return 'p < 0.001'
354+
else:
355+
return f'p = {pval:.3f}'

code/notebooks/network_analyses.ipynb

Lines changed: 872 additions & 242 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)