9
9
from sklearn .decomposition import IncrementalPCA as PCA
10
10
from tqdm import tqdm
11
11
from scipy .spatial .distance import cdist
12
+ import statsmodels .api as sm
12
13
13
14
import os
14
15
import warnings
@@ -100,10 +101,6 @@ def fullfact(dims):
100
101
vox_coords = fullfact (img .shape [0 :3 ])[vmask , ::- 1 ]- 1
101
102
102
103
R = np .array (np .dot (vox_coords , S [0 :3 , 0 :3 ])) + S [:3 , 3 ]
103
-
104
- # center on the MNI152 brain (hard code this in)
105
- #mni_center = np.array([0.55741881, -21.52140703, 9.83783098])
106
- #R = R - R.mean(axis=0) + mni_center
107
104
108
105
return {'Y' : Y , 'R' : R }
109
106
@@ -208,13 +205,39 @@ def cross_validation(data, n_iter=10, fname=None, max_components=700):
208
205
return results
209
206
210
207
211
- def ridge_plot (x , column = 'Number of components' , fname = None , xlim = [- 99 , 700 ], hue = 'Condition' , palette = [condition_colors [c ] for c in conditions ]):
208
+ def ridge_plot (x , column = 'Number of components' , fname = None , xlim = [- 99 , 700 ], hue = 'Condition' , palette = [condition_colors [c ] for c in conditions ], scale_start = 0.25 , scale_height = 0.1 ):
209
+ def pdf_plot (x , ax = None , xlim = [0.0 , 1.0 ], resolution = 1000 , ** kwargs ):
210
+ if ax is None :
211
+ ax = plt .gca ()
212
+
213
+ if 'color' in kwargs :
214
+ color = kwargs ['color' ]
215
+ kwargs .pop ('color' )
216
+ else :
217
+ color = 'k'
218
+
219
+ density = sm .nonparametric .KDEUnivariate (x )
220
+ density .fit ()
221
+
222
+ xs = np .linspace (xlim [0 ], xlim [1 ], resolution )
223
+ ys = density .evaluate (xs )
224
+ ys = ys / np .sum (ys )
225
+
226
+ ax .fill (xs , ys , color = color , ** kwargs )
227
+ return ax
212
228
213
229
sns .set_theme (style = "white" , rc = {"axes.facecolor" : (0 , 0 , 0 , 0 )})
214
230
g = sns .FacetGrid (x , row = hue , hue = hue , palette = palette , height = 1 , aspect = 6 )
215
- g .map (sns .kdeplot , column , bw_adjust = 1 , clip_on = True , fill = True , alpha = 1 , common_norm = True , linewidth = 1.5 )
231
+ # g.map(sns.kdeplot, column, bw_adjust=1, clip_on=True, fill=True, alpha=1, common_norm=True, linewidth=1.5) # REPLACE WITH A NEW FUNCTION THAT NORMALIZES AREA TO 1
232
+ g .map (pdf_plot , column , xlim = xlim , linewidth = 1.5 )
216
233
g .refline (y = 0 , linewidth = 1.5 , linestyle = '-' , color = None , clip_on = False )
217
234
235
+ # plot a scale bar in the upper right
236
+ if scale_height is not None :
237
+ # compute the x position of the scale bar -- 98% of the way to the right
238
+ x = xlim [0 ] + 0.98 * (xlim [1 ] - xlim [0 ])
239
+ g .axes [0 ][0 ].plot ([x , x ], [scale_start , scale_start + scale_height ], color = 'k' , linewidth = 1.5 )
240
+
218
241
def label (x , color , label ):
219
242
ax = plt .gca ()
220
243
ax .text (0 , 0.2 , label .capitalize (), color = color , ha = 'left' , va = 'center' , transform = ax .transAxes )
@@ -231,7 +254,7 @@ def label(x, color, label):
231
254
232
255
ax = plt .gca ()
233
256
ax .set_xlim (xlim [0 ], xlim [1 ])
234
- ax .set_xlabel (column , fontsize = 12 )
257
+ ax .set_xlabel (column , fontsize = 12 )
235
258
236
259
if fname is not None :
237
260
g .savefig (os .path .join (figdir , fname + '.pdf' ), bbox_inches = 'tight' )
@@ -254,7 +277,7 @@ def get_data():
254
277
return data
255
278
256
279
257
- def info_and_compressibility (d , target = 0.05 ):
280
+ def info_and_compressibility (d , target = None ):
258
281
def closest (x , target ):
259
282
dists = np .abs (x .values - target )
260
283
dists [x .values < target ] += 10 * np .max (dists )
@@ -264,19 +287,23 @@ def closest(x, target):
264
287
for c in conditions :
265
288
dc = d [c ].astype (float ).pivot (index = 'Iteration' , columns = 'Number of components' , values = 'Relative decoding accuracy' )
266
289
i = pd .DataFrame ()
267
- i ['Number of components' ] = dc .apply (lambda x : closest (x , target ), axis = 1 , raw = False )
290
+
291
+ if target is None :
292
+ i ['Number of components' ] = dc .idxmax (axis = 1 ).astype (int )
293
+ else :
294
+ i ['Number of components' ] = dc .apply (lambda x : closest (x , target ), axis = 1 , raw = False )
268
295
i ['Relative decoding accuracy' ] = dc .max (axis = 1 )
269
296
i ['Condition' ] = c
270
297
i ['Iteration' ] = dc .index .values .astype (int )
271
298
df .append (i )
272
299
return pd .concat (df , ignore_index = True , axis = 0 )
273
300
274
301
275
- def plot_info_and_compressibility_scatter (x , fname = None ):
302
+ def plot_info_and_compressibility_scatter (x , fname = None , target = None ):
276
303
fig = plt .figure (figsize = (4 , 3 ))
277
304
ax = plt .gca ()
278
305
279
- x = info_and_compressibility (x )
306
+ x = info_and_compressibility (x , target = target )
280
307
sns .scatterplot (x , x = 'Number of components' , y = 'Relative decoding accuracy' , hue = 'Condition' , palette = [condition_colors [c ] for c in conditions ], legend = False , s = 10 , ax = ax )
281
308
sns .scatterplot (x .groupby ('Condition' ).mean ().loc [conditions ].reset_index (), x = 'Number of components' , y = 'Relative decoding accuracy' , hue = 'Condition' , palette = [condition_colors [c ] for c in conditions ], legend = False , s = 100 , ax = ax )
282
309
@@ -293,4 +320,36 @@ def plot_info_and_compressibility_scatter(x, fname=None):
293
320
return fig
294
321
295
322
def rbf (R , center , width ):
296
- return np .exp (- np .sum ((R - center ) ** 2 , axis = 1 ) / width )
323
+ return np .exp (- np .sum ((R - center ) ** 2 , axis = 1 ) / width )
324
+
325
+
326
+ def plot_accuracy (x , figdir = None , fname = None , conditions = ['intact' , 'paragraph' , 'word' , 'rest' ], condition_colors = condition_colors , ylim = [- 0.01 , 0.35 ], xlim = [3 , 700 ], ax = None ):
327
+ if figdir is not None and not os .path .exists (figdir ):
328
+ os .makedirs (figdir )
329
+
330
+ if ax is None :
331
+ fig = plt .figure (figsize = (4 , 3 ))
332
+ ax = plt .gca ()
333
+ else :
334
+ fig = plt .gcf ()
335
+
336
+ for c in conditions :
337
+ sns .lineplot (x [c ], x = 'Number of components' , y = 'Relative decoding accuracy' , label = c .capitalize (), color = condition_colors [c ], legend = False , ax = ax )
338
+
339
+ ax .set_xlabel ('Number of components' , fontsize = 12 )
340
+ ax .set_ylabel ('Relative decoding accuracy' , fontsize = 12 )
341
+ ax .set_ylim (ylim )
342
+ ax .set_xlim (xlim )
343
+ ax .spines [['right' , 'top' ]].set_visible (False )
344
+
345
+ if fname is not None :
346
+ fig .savefig (os .path .join (figdir , fname + '.pdf' ), bbox_inches = 'tight' )
347
+
348
+ return fig
349
+
350
+
351
+ def pstring (pval ):
352
+ if pval < 0.001 :
353
+ return 'p < 0.001'
354
+ else :
355
+ return f'p = { pval :.3f} '
0 commit comments