-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implement Blank2014-linear benchmark; unify some methods with Pereira…
…2018
- Loading branch information
Showing
8 changed files
with
390 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import numpy as np | ||
from typing import List | ||
|
||
from brainio.assemblies import walk_coords, array_is_element, DataAssembly | ||
|
||
|
||
def ci_error(samples, center, confidence=.95): | ||
low, high = 100 * ((1 - confidence) / 2), 100 * (1 - ((1 - confidence) / 2)) | ||
confidence_below, confidence_above = np.nanpercentile(samples, low), np.nanpercentile(samples, high) | ||
confidence_below, confidence_above = center - confidence_below, confidence_above - center | ||
return confidence_below, confidence_above | ||
|
||
|
||
def manual_merge(*elements: List[DataAssembly], on='neuroid') -> DataAssembly: | ||
""" | ||
Manually merge a set of assemblies where xarray's automated merge might fail. | ||
This function likely covers only covers a small number of use-cases, and should thus be used with caution. | ||
""" | ||
dims = elements[0].dims | ||
assert all(element.dims == dims for element in elements[1:]) | ||
merge_index = dims.index(on) | ||
# the coordinates in the merge index should have the same keys | ||
assert _coords_match(elements, dim=on, | ||
match_values=False), f"coords in {[element[on] for element in elements]} do not match" | ||
# all other dimensions, their coordinates and values should already align | ||
for dim in set(dims) - {on}: | ||
assert _coords_match(elements, dim=dim, | ||
match_values=True), f"coords in {[element[dim] for element in elements]} do not match" | ||
# merge values without meta | ||
merged_values = np.concatenate([element.values for element in elements], axis=merge_index) | ||
# piece together with meta | ||
result = type(elements[0])(merged_values, coords={ | ||
**{coord: (dims, values) | ||
for coord, dims, values in walk_coords(elements[0]) | ||
if not array_is_element(dims, on)}, | ||
**{coord: (dims, np.concatenate([element[coord].values for element in elements])) | ||
for coord, dims, _ in walk_coords(elements[0]) | ||
if array_is_element(dims, on)}}, dims=elements[0].dims) | ||
return result | ||
|
||
|
||
def _coords_match(elements, dim, match_values=False): | ||
""" Helper method for `manual_merge` """ | ||
first_coords = [(key, tuple(value)) if match_values else key for _, key, value in walk_coords(elements[0][dim])] | ||
other_coords = [[(key, tuple(value)) if match_values else key for _, key, value in walk_coords(element[dim])] | ||
for element in elements[1:]] | ||
return all(tuple(first_coords) == tuple(coords) for coords in other_coords) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from brainscore_language import benchmark_registry | ||
from .benchmark import Blank2014Linear | ||
|
||
benchmark_registry['Blank2014-linear'] = Blank2014Linear |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import xarray as xr | ||
|
||
from brainscore_core.benchmarks import BenchmarkBase | ||
from brainscore_core.metrics import Score | ||
from brainscore_language import load_dataset, load_metric | ||
from brainscore_language.artificial_subject import ArtificialSubject | ||
from brainscore_language.benchmarks.blank2014.ceiling import ExtrapolationCeiling | ||
from brainscore_language.data.blank2014 import BIBTEX | ||
from brainscore_language.utils.ceiling import ceiling_normalize | ||
|
||
|
||
class Blank2014Linear(BenchmarkBase): | ||
""" | ||
Evaluate model ability to predict neural activity in human language system functional regions of interest (fROIs) | ||
in response to natural stories, recorded by Blank et al. 2014. | ||
Alignment of neural activity between model and human subjects is evaluated via cross-validated linear predictivity. | ||
This benchmark builds off the Blank2014 benchmark introduced in Schrimpf et al. 2021 | ||
(https://www.pnas.org/doi/10.1073/pnas.2105646118), but requires the model to have committed to neural readouts | ||
(e.g. "layer 41 corresponds to the language system"), rather than testing every layer separately. | ||
""" | ||
|
||
def __init__(self): | ||
self.data = load_dataset('Blank2014.fROI') | ||
self.metric = load_metric('linear_pearsonr') | ||
ceiler = ExtrapolationCeiling() | ||
ceiling = ceiler(assembly=self.data, metric=self.metric) | ||
super(Blank2014Linear, self).__init__( | ||
identifier='Blank2014-linear', | ||
version=1, | ||
parent='neural_language', | ||
ceiling=ceiling, | ||
bibtex=BIBTEX) | ||
|
||
def __call__(self, candidate: ArtificialSubject) -> Score: | ||
candidate.start_neural_recording(recording_target=ArtificialSubject.RecordingTarget.language_system, | ||
recording_type=ArtificialSubject.RecordingType.fMRI) | ||
stimuli = self.data['stimulus'] | ||
stories = self.data['story'].values | ||
predictions = [] | ||
for story in sorted(set(stories)): # go over individual stories, sorting to keep consistency across runs | ||
story_indexer = [stimulus_story == story for stimulus_story in stories] | ||
story_stimuli = stimuli[story_indexer] | ||
story_predictions = candidate.digest_text(story_stimuli.values)['neural'] | ||
story_predictions['stimulus_id'] = 'presentation', story_stimuli['stimulus_id'].values | ||
predictions.append(story_predictions) | ||
predictions = xr.concat(predictions, dim='presentation') | ||
raw_score = self.metric(predictions, self.data) | ||
score = ceiling_normalize(raw_score, self.ceiling) | ||
return score |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
import itertools | ||
import logging | ||
import numpy as np | ||
from numpy.random import RandomState | ||
from scipy.optimize import curve_fit | ||
from tqdm import tqdm, trange | ||
|
||
from brainio.assemblies import array_is_element, walk_coords, DataAssembly, merge_data_arrays | ||
from brainscore_core.metrics import Score | ||
from brainscore_language.benchmark_helpers import ci_error, manual_merge | ||
from brainscore_language.utils import fullname | ||
from brainscore_language.utils.transformations import apply_aggregate | ||
|
||
|
||
def v(x, v0, tau0): | ||
return v0 * (1 - np.exp(-x / tau0)) | ||
|
||
|
||
class ExtrapolationCeiling: | ||
def __init__(self, subject_column='subject_id', extrapolation_dimension='neuroid', num_bootstraps=100): | ||
self._logger = logging.getLogger(fullname(self)) | ||
self.subject_column = subject_column | ||
self.holdout_ceiling = HoldoutSubjectCeiling(subject_column=subject_column) | ||
self.extrapolation_dimension = extrapolation_dimension | ||
self.num_bootstraps = num_bootstraps | ||
|
||
def __call__(self, assembly, metric): | ||
scores = self.collect(assembly=assembly, metric=metric) | ||
return self.extrapolate(scores) | ||
|
||
def collect(self, assembly, metric): | ||
num_subjects = len(set(assembly[self.subject_column].values)) | ||
subject_subsamples = self.build_subject_subsamples(num_subjects) | ||
scores = [] | ||
for num_subjects in tqdm(subject_subsamples, desc='num subjects'): | ||
selection_combinations = self.iterate_subsets(assembly, num_subjects=num_subjects) | ||
for selections, sub_assembly in tqdm(selection_combinations, desc='selections'): | ||
score = self.holdout_ceiling(assembly=sub_assembly, metric=metric) | ||
score = score.expand_dims('num_subjects') | ||
score['num_subjects'] = [num_subjects] | ||
for key, selection in selections.items(): | ||
expand_dim = f'sub_{key}' | ||
score = score.expand_dims(expand_dim) | ||
score[expand_dim] = [str(selection)] | ||
scores.append(score.raw) | ||
scores = Score.merge(*scores) | ||
assert hasattr(scores, 'neuroid_id') | ||
return scores | ||
|
||
def build_subject_subsamples(self, num_subjects): | ||
return tuple(range(2, num_subjects + 1)) | ||
|
||
def iterate_subsets(self, assembly, num_subjects): | ||
subjects = set(assembly[self.subject_column].values) | ||
subject_combinations = list(itertools.combinations(sorted(subjects), num_subjects)) | ||
for sub_subjects in subject_combinations: | ||
sub_assembly = assembly[{'neuroid': [subject in sub_subjects | ||
for subject in assembly[self.subject_column].values]}] | ||
yield {self.subject_column: sub_subjects}, sub_assembly | ||
|
||
def average_collected(self, scores): | ||
return scores.median('neuroid') | ||
|
||
def extrapolate(self, ceilings): | ||
neuroid_ceilings, bootstrap_params, endpoint_xs = [], [], [] | ||
for i in trange(len(ceilings[self.extrapolation_dimension]), | ||
desc=f'{self.extrapolation_dimension} extrapolations'): | ||
# extrapolate per-neuroid ceiling | ||
neuroid_ceiling = ceilings.isel(**{self.extrapolation_dimension: [i]}) | ||
extrapolated_ceiling = self.extrapolate_neuroid(neuroid_ceiling.squeeze()) | ||
extrapolated_ceiling = self.add_neuroid_meta(extrapolated_ceiling, neuroid_ceiling) | ||
neuroid_ceilings.append(extrapolated_ceiling) | ||
# also keep track of bootstrapped parameters | ||
neuroid_bootstrap_params = extrapolated_ceiling.bootstrapped_params | ||
neuroid_bootstrap_params = self.add_neuroid_meta(neuroid_bootstrap_params, neuroid_ceiling) | ||
bootstrap_params.append(neuroid_bootstrap_params) | ||
# and endpoints | ||
endpoint_x = self.add_neuroid_meta(extrapolated_ceiling.endpoint_x, neuroid_ceiling) | ||
endpoint_xs.append(endpoint_x) | ||
# merge and add meta | ||
self._logger.debug("Merging neuroid ceilings") | ||
neuroid_ceilings = manual_merge(*neuroid_ceilings, on=self.extrapolation_dimension) | ||
neuroid_ceilings.attrs['raw'] = ceilings | ||
self._logger.debug("Merging bootstrap params") | ||
bootstrap_params = manual_merge(*bootstrap_params, on=self.extrapolation_dimension) | ||
neuroid_ceilings.attrs['bootstrapped_params'] = bootstrap_params | ||
self._logger.debug("Merging endpoints") | ||
endpoint_xs = manual_merge(*endpoint_xs, on=self.extrapolation_dimension) | ||
neuroid_ceilings.attrs['endpoint_x'] = endpoint_xs | ||
# aggregate | ||
ceiling = self.aggregate_neuroid_ceilings(neuroid_ceilings) | ||
return ceiling | ||
|
||
def add_neuroid_meta(self, target, source): | ||
target = target.expand_dims(self.extrapolation_dimension) | ||
for coord, dims, values in walk_coords(source): | ||
if array_is_element(dims, self.extrapolation_dimension): | ||
target[coord] = dims, values | ||
return target | ||
|
||
def aggregate_neuroid_ceilings(self, neuroid_ceilings): | ||
ceiling = neuroid_ceilings.median(self.extrapolation_dimension) | ||
ceiling.attrs['bootstrapped_params'] = neuroid_ceilings.bootstrapped_params.median(self.extrapolation_dimension) | ||
ceiling.attrs['endpoint_x'] = neuroid_ceilings.endpoint_x.median(self.extrapolation_dimension) | ||
ceiling.attrs['raw'] = neuroid_ceilings | ||
return ceiling | ||
|
||
def extrapolate_neuroid(self, ceilings): | ||
# figure out how many extrapolation x points we have. E.g. for Pereira, not all combinations are possible | ||
subject_subsamples = list(sorted(set(ceilings['num_subjects'].values))) | ||
rng = RandomState(0) | ||
bootstrap_params = [] | ||
for bootstrap in range(self.num_bootstraps): | ||
bootstrapped_scores = [] | ||
for num_subjects in subject_subsamples: | ||
num_scores = ceilings.sel(num_subjects=num_subjects) | ||
# the sub_subjects dimension creates nans, get rid of those | ||
num_scores = num_scores.dropna(f'sub_{self.subject_column}') | ||
assert set(num_scores.dims) == {f'sub_{self.subject_column}', 'split'} or \ | ||
set(num_scores.dims) == {f'sub_{self.subject_column}'} | ||
# choose from subject subsets and the splits therein, with replacement for variance | ||
choices = num_scores.values.flatten() | ||
bootstrapped_score = rng.choice(choices, size=len(choices), replace=True) | ||
bootstrapped_scores.append(np.mean(bootstrapped_score)) | ||
|
||
params = self.fit(subject_subsamples, bootstrapped_scores) | ||
params = DataAssembly([params], coords={'bootstrap': [bootstrap], 'param': ['v0', 'tau0']}, | ||
dims=['bootstrap', 'param']) | ||
bootstrap_params.append(params) | ||
bootstrap_params = merge_data_arrays(bootstrap_params) | ||
# find endpoint and error | ||
asymptote_threshold = .0005 | ||
interpolation_xs = np.arange(1000) | ||
ys = np.array([v(interpolation_xs, *params) for params in bootstrap_params.values | ||
if not np.isnan(params).any()]) | ||
median_ys = np.median(ys, axis=0) | ||
diffs = np.diff(median_ys) | ||
end_x = np.where(diffs < asymptote_threshold)[0].min() # first x where increase smaller than threshold | ||
# put together | ||
center = np.median(np.array(bootstrap_params)[:, 0]) | ||
error_low, error_high = ci_error(ys[:, end_x], center=center) | ||
score = Score(center) | ||
score.attrs['error_low'] = error_low | ||
score.attrs['error_high'] = error_high | ||
score.attrs['raw'] = ceilings | ||
score.attrs['bootstrapped_params'] = bootstrap_params | ||
score.attrs['endpoint_x'] = DataAssembly(end_x) | ||
return score | ||
|
||
def fit(self, subject_subsamples, bootstrapped_scores): | ||
params, pcov = curve_fit(v, subject_subsamples, bootstrapped_scores, | ||
# v (i.e. max ceiling) is between 0 and 1, tau0 unconstrained | ||
bounds=([0, -np.inf], [1, np.inf])) | ||
return params | ||
|
||
|
||
class HoldoutSubjectCeiling: | ||
def __init__(self, subject_column): | ||
self.subject_column = subject_column | ||
self._logger = logging.getLogger(fullname(self)) | ||
|
||
def __call__(self, assembly, metric): | ||
subjects = set(assembly[self.subject_column].values) | ||
scores = [] | ||
iterate_subjects = self.get_subject_iterations(subjects) | ||
for subject in tqdm(iterate_subjects, desc='heldout subject'): | ||
try: | ||
subject_assembly = assembly[{'neuroid': [subject_value == subject | ||
for subject_value in assembly[self.subject_column].values]}] | ||
# run subject pool as neural candidate | ||
subject_pool = subjects - {subject} | ||
pool_assembly = assembly[ | ||
{'neuroid': [subject in subject_pool for subject in assembly[self.subject_column].values]}] | ||
score = self.score(pool_assembly, subject_assembly, metric=metric) | ||
# store scores | ||
apply_raw = 'raw' in score.attrs and \ | ||
not hasattr(score.raw, self.subject_column) # only propagate if column not part of score | ||
score = score.expand_dims(self.subject_column, _apply_raw=apply_raw) | ||
score.__setitem__(self.subject_column, [subject], _apply_raw=apply_raw) | ||
scores.append(score) | ||
except NoOverlapException as e: | ||
self._logger.debug(f"Ignoring no overlap {e}") | ||
continue # ignore | ||
except ValueError as e: | ||
if "Found array with" in str(e): | ||
self._logger.debug(f"Ignoring empty array {e}") | ||
continue | ||
else: | ||
raise e | ||
|
||
scores = Score.merge(*scores) | ||
score = apply_aggregate(lambda scores: scores.mean(self.subject_column), scores) | ||
score.attrs['error'] = scores.std(self.subject_column) | ||
return scores | ||
|
||
def get_subject_iterations(self, subjects): | ||
return subjects # iterate over all subjects | ||
|
||
def score(self, pool_assembly, subject_assembly, metric): | ||
return metric(pool_assembly, subject_assembly) | ||
|
||
|
||
class NoOverlapException(Exception): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import copy | ||
import numpy as np | ||
from numpy.random import RandomState | ||
from pytest import approx | ||
from typing import Callable, Union, List | ||
|
||
from brainio.assemblies import NeuroidAssembly | ||
from brainscore_language import ArtificialSubject, load_benchmark | ||
|
||
|
||
class TestBenchmark: | ||
class DummyModel(ArtificialSubject): | ||
def __init__(self, activity_for_text: Callable[[Union[str, List[str]]], NeuroidAssembly]): | ||
self.activity_for_text = activity_for_text | ||
|
||
def digest_text(self, stimuli): | ||
neural_activity = self.activity_for_text(stimuli) | ||
return {'neural': neural_activity} | ||
|
||
def start_neural_recording(self, recording_target: ArtificialSubject.RecordingTarget, | ||
recording_type: ArtificialSubject.RecordingType): | ||
assert recording_target == ArtificialSubject.RecordingTarget.language_system | ||
assert recording_type == ArtificialSubject.RecordingType.fMRI | ||
|
||
def test_dummy_bad(self): | ||
random_state = RandomState(0) | ||
|
||
def activity_for_text(stimuli: Union[str, List[str]]) -> NeuroidAssembly: | ||
num_stimuli = len(stimuli) | ||
num_neuroids = 25 | ||
neural_activity = random_state.random(size=(num_stimuli, num_neuroids)) # presentation x neuroid | ||
neural_activity = NeuroidAssembly(neural_activity, | ||
coords={'stimulus_seq': ('presentation', np.arange(num_stimuli)), | ||
'stimulus_num': ('presentation', np.arange(num_stimuli)), | ||
'neuroid_id': ('neuroid', np.arange(num_neuroids)), | ||
'region': ('neuroid', ['some_region'] * num_neuroids)}, | ||
dims=['presentation', 'neuroid']) | ||
neural_activity['stimulus'] = 'presentation', stimuli # copy over | ||
return neural_activity | ||
|
||
benchmark = load_benchmark('Blank2014-linear') | ||
dummy_model = TestBenchmark.DummyModel(activity_for_text=activity_for_text) | ||
score = benchmark(dummy_model) | ||
assert score == 0 | ||
|
||
def test_exact(self): | ||
benchmark = load_benchmark('Blank2014-linear') | ||
exact_data = copy.deepcopy(benchmark.data) | ||
|
||
def activity_for_text(stimuli: Union[str, List[str]]) -> NeuroidAssembly: | ||
passage_activity = exact_data[{'presentation': [ | ||
list(exact_data['stimulus'].values).index(stimulus) for stimulus in stimuli]}] | ||
# remove stimulus_id and stimulus coordinates to not trip up benchmark | ||
passage_activity = passage_activity.reset_index('presentation') | ||
del passage_activity['stimulus_id'] | ||
passage_activity = NeuroidAssembly(passage_activity) # index | ||
return passage_activity | ||
|
||
dummy_model = TestBenchmark.DummyModel(activity_for_text=activity_for_text) | ||
score = benchmark(dummy_model) | ||
assert score == approx(1) | ||
|
||
def test_ceiling(self): | ||
benchmark = load_benchmark(f'Blank2014-linear') | ||
ceiling = benchmark.ceiling | ||
assert ceiling == approx(.21026591, abs=.0005) | ||
|
||
def test_ceiling_raw(self): | ||
benchmark = load_benchmark(f'Blank2014-linear') | ||
ceiling = benchmark.ceiling | ||
assert hasattr(ceiling, 'raw') | ||
assert set(ceiling.raw.dims) == {'neuroid'} | ||
assert ceiling.raw.median() == ceiling | ||
assert hasattr(ceiling.raw, 'raw') | ||
assert set(ceiling.raw.raw.dims) == {'sub_subject_id', 'num_subjects', 'split', 'neuroid'} |
Oops, something went wrong.