Skip to content

Commit

Permalink
Merge pull request #121 from vathes/master
Browse files Browse the repository at this point in the history
report and publication fixes
  • Loading branch information
LiuDaveLiu authored Nov 18, 2019
2 parents 69f1d82 + 770b015 commit 5736578
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 122 deletions.
5 changes: 5 additions & 0 deletions pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
log = logging.getLogger(__name__)


# safe-guard in case `custom` is not provided
if 'custom' not in dj.config:
dj.config['custom'] = {}


def get_schema_name(name):
try:
return dj.config['custom']['{}.database'.format(name)]
Expand Down
23 changes: 14 additions & 9 deletions pipeline/fixes/fix_0002_delay_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def fix_session(session_key):
if len(filelist) != len(files):
log.warning("behavior files missing in {} ({}/{}). skipping".format(
session_key, len(filelist), len(files)))
return
return False

log.info('filelist: {}'.format(filelist))

Expand Down Expand Up @@ -126,7 +126,7 @@ def fix_session(session_key):
# all files were internally invalid or size < 100k
if not trials:
log.warning('skipping ., no valid files')
return
return False

key = session_key
skey = (experiment.Session & key).fetch1()
Expand Down Expand Up @@ -538,6 +538,8 @@ def fix_session(session_key):
rows['corrected_trial_event'], ignore_extra_fields=True,
allow_direct_insert=True)

return True


def verify_session(s):
log.info('verifying_session {}'.format(s))
Expand All @@ -559,23 +561,23 @@ def note_prob(s, e, msg):

if newstate == 'presample':
if state and state not in {'presample', 'trialend'}:
note_prob(s, e)
note_prob(s, e, 'trialend !-> presample')
nerr += 1
if newstate == 'sample':
if state and state not in {'presample', 'sample'}:
note_prob(s, e)
note_prob(s, e, 'presaple !-> sample')
nerr += 1
if newstate == 'delay':
if state and state not in {'sample', 'delay'}:
note_prob(s, e)
note_prob(s, e, 'sample !-> delay')
nerr += 1
if newstate == 'go':
if state and state not in {'delay', 'go'}:
note_prob(s, e)
note_prob(s, e, 'delay !-> go')
nerr += 1
if newstate == 'trialend':
if state and state not in {'go', 'trialend'}:
note_prob(s, e)
note_prob(s, e, 'go !-> trialend')
nerr += 1

eid, state = neweid, newstate
Expand All @@ -585,6 +587,7 @@ def note_prob(s, e, msg):
else:
log.warning('session {} had {} verification errors.'.format(s, nerr))


def fix_0002_delay_events():
with dj.conn().transaction:

Expand All @@ -596,8 +599,10 @@ def fix_0002_delay_events():
q = (experiment.Session & behavior_ingest.BehaviorIngest)

for s in q.fetch('KEY'):
fix_session(s)
verify_session(s)
if fix_session(s):
verify_session(s)
else:
log.warning('session {} verify skipped - not fixed'.format(s))


if __name__ == '__main__':
Expand Down
113 changes: 73 additions & 40 deletions pipeline/plot/unit_characteristic_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,25 @@
from pipeline import experiment, ephys, psth

from pipeline.plot.util import (_plot_with_sem, _extract_one_stim_dur, _get_units_hemisphere,
_plot_stacked_psth_diff, _plot_avg_psth,
jointplot_w_hue)
_get_trial_event_times, _get_clustering_method,
_plot_stacked_psth_diff, _plot_avg_psth, jointplot_w_hue)

m_scale = 1200
_plt_xmin = -3
_plt_xmax = 2


def plot_clustering_quality(probe_insertion, axs=None):
def plot_clustering_quality(probe_insertion, clustering_method=None, axs=None):
probe_insertion = probe_insertion.proj()
amp, snr, spk_rate, isi_violation = (ephys.Unit * ephys.UnitStat
* ephys.ProbeInsertion.InsertionLocation & probe_insertion).fetch(

if clustering_method is None:
try:
clustering_method = _get_clustering_method(probe_insertion)
except ValueError as e:
raise ValueError(str(e) + '\nPlease specify one with the kwarg "clustering_method"')

amp, snr, spk_rate, isi_violation = (ephys.Unit * ephys.UnitStat * ephys.ProbeInsertion.InsertionLocation
& probe_insertion & {'clustering_method': clustering_method}).fetch(
'unit_amp', 'unit_snr', 'avg_firing_rate', 'isi_violation')

metrics = {'amp': amp,
Expand Down Expand Up @@ -52,11 +59,18 @@ def plot_clustering_quality(probe_insertion, axs=None):
return fig


def plot_unit_characteristic(probe_insertion, axs=None):
def plot_unit_characteristic(probe_insertion, clustering_method=None, axs=None):
probe_insertion = probe_insertion.proj()

if clustering_method is None:
try:
clustering_method = _get_clustering_method(probe_insertion)
except ValueError as e:
raise ValueError(str(e) + '\nPlease specify one with the kwarg "clustering_method"')

amp, snr, spk_rate, x, y, insertion_depth = (
ephys.Unit * ephys.ProbeInsertion.InsertionLocation * ephys.UnitStat
& probe_insertion & 'unit_quality != "all"').fetch(
& probe_insertion & {'clustering_method': clustering_method} & 'unit_quality != "all"').fetch(
'unit_amp', 'unit_snr', 'avg_firing_rate', 'unit_posx', 'unit_posy', 'dv_location')

insertion_depth = np.where(np.isnan(insertion_depth), 0, insertion_depth)
Expand Down Expand Up @@ -102,12 +116,20 @@ def plot_unit_characteristic(probe_insertion, axs=None):
return fig


def plot_unit_selectivity(probe_insertion, axs=None):
def plot_unit_selectivity(probe_insertion, clustering_method=None, axs=None):
probe_insertion = probe_insertion.proj()

if clustering_method is None:
try:
clustering_method = _get_clustering_method(probe_insertion)
except ValueError as e:
raise ValueError(str(e) + '\nPlease specify one with the kwarg "clustering_method"')

attr_names = ['unit', 'period', 'period_selectivity', 'contra_firing_rate',
'ipsi_firing_rate', 'unit_posx', 'unit_posy', 'dv_location']
'ipsi_firing_rate', 'unit_posx', 'unit_posy', 'dv_location']
selective_units = (psth.PeriodSelectivity * ephys.Unit * ephys.ProbeInsertion.InsertionLocation
* experiment.Period & probe_insertion & 'period_selectivity != "non-selective"').fetch(*attr_names)
* experiment.Period & probe_insertion & {'clustering_method': clustering_method}
& 'period_selectivity != "non-selective"').fetch(*attr_names)
selective_units = pd.DataFrame(selective_units).T
selective_units.columns = attr_names
selective_units.period_selectivity.astype('category')
Expand Down Expand Up @@ -162,10 +184,16 @@ def plot_unit_selectivity(probe_insertion, axs=None):
return fig


def plot_unit_bilateral_photostim_effect(probe_insertion, axs=None):
def plot_unit_bilateral_photostim_effect(probe_insertion, clustering_method=None, axs=None):
probe_insertion = probe_insertion.proj()

if clustering_method is None:
try:
clustering_method = _get_clustering_method(probe_insertion)
except ValueError as e:
raise ValueError(str(e) + '\nPlease specify one with the kwarg "clustering_method"')

dv_loc = (ephys.ProbeInsertion.InsertionLocation & probe_insertion).fetch1('dv_location')
cue_onset = (experiment.Period & 'period = "delay"').fetch1('period_start')

no_stim_cond = (psth.TrialCondition
& {'trial_condition_name':
Expand All @@ -181,24 +209,32 @@ def plot_unit_bilateral_photostim_effect(probe_insertion, axs=None):
& probe_insertion).fetch('duration'))
stim_dur = _extract_one_stim_dur(stim_durs)

units = ephys.Unit & probe_insertion & 'unit_quality != "all"'
units = ephys.Unit & probe_insertion & {'clustering_method': clustering_method} & 'unit_quality != "all"'

metrics = pd.DataFrame(columns=['unit', 'x', 'y', 'frate_change'])

metrics = pd.DataFrame(columns=['unit', 'x', 'y', 'frate_change']) # TODO: account for dv_location
_, cue_onset = _get_trial_event_times(['delay'], units, 'all_noearlylick_both_alm_nostim')
cue_onset = cue_onset[0]

# XXX: could be done with 1x fetch+join
for u_idx, unit in enumerate(units.fetch('KEY', order_by='unit')):

x, y = (ephys.Unit & unit).fetch1('unit_posx', 'unit_posy')

nostim_psth, nostim_edge = (
psth.UnitPsth & {**unit, **no_stim_cond}).fetch1('unit_psth')
# obtain unit psth per trial, for all nostim and bistim trials
nostim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(no_stim_cond['trial_condition_name'])
bistim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(bi_stim_cond['trial_condition_name'])

bistim_psth, bistim_edge = (
psth.UnitPsth & {**unit, **bi_stim_cond}).fetch1('unit_psth')
nostim_psths, nostim_edge = psth.compute_unit_psth(unit, nostim_trials.fetch('KEY'), per_trial=True)
bistim_psths, bistim_edge = psth.compute_unit_psth(unit, bistim_trials.fetch('KEY'), per_trial=True)

# compute the firing rate difference between contra vs. ipsi within the stimulation duration
ctrl_frate = nostim_psth[np.logical_and(nostim_edge[1:] >= cue_onset, nostim_edge[1:] <= cue_onset + stim_dur)]
stim_frate = bistim_psth[np.logical_and(bistim_edge[1:] >= cue_onset, bistim_edge[1:] <= cue_onset + stim_dur)]
ctrl_frate = np.array([nostim_psth[np.logical_and(nostim_edge >= cue_onset,
nostim_edge <= cue_onset + stim_dur)].mean()
for nostim_psth in nostim_psths])
stim_frate = np.array([bistim_psth[np.logical_and(bistim_edge >= cue_onset,
bistim_edge <= cue_onset + stim_dur)].mean()
for bistim_psth in bistim_psths])

frate_change = (stim_frate.mean() - ctrl_frate.mean()) / ctrl_frate.mean()
frate_change = abs(frate_change) if frate_change < 0 else 0.0001
Expand Down Expand Up @@ -230,9 +266,8 @@ def plot_unit_bilateral_photostim_effect(probe_insertion, axs=None):
def plot_stacked_contra_ipsi_psth(units, axs=None):
units = units.proj()

period_starts = (experiment.Period
& 'period in ("sample", "delay", "response")').fetch(
'period_start')
# get event start times: sample, delay, response
period_names, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

hemi = _get_units_hemisphere(units)

Expand Down Expand Up @@ -285,9 +320,8 @@ def plot_stacked_contra_ipsi_psth(units, axs=None):
def plot_avg_contra_ipsi_psth(units, axs=None):
units = units.proj()

period_starts = (experiment.Period
& 'period in ("sample", "delay", "response")').fetch(
'period_start')
# get event start times: sample, delay, response
period_names, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

hemi = _get_units_hemisphere(units)

Expand Down Expand Up @@ -349,10 +383,6 @@ def plot_psth_bilateral_photostim_effect(units, axs=None):

hemi = _get_units_hemisphere(units)

period_starts = (experiment.Period
& 'period in ("sample", "delay", "response")').fetch(
'period_start')

psth_s_l = (psth.UnitPsth * psth.TrialCondition & units
& {'trial_condition_name':
'all_noearlylick_both_alm_stim_left'}).fetch('unit_psth')
Expand All @@ -369,6 +399,9 @@ def plot_psth_bilateral_photostim_effect(units, axs=None):
& {'trial_condition_name':
'all_noearlylick_both_alm_nostim_right'}).fetch('unit_psth')

# get event start times: sample, delay, response
period_names, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

# get photostim duration
stim_durs = np.unique((experiment.Photostim & experiment.PhotostimEvent
* psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
Expand Down Expand Up @@ -402,9 +435,8 @@ def plot_psth_bilateral_photostim_effect(units, axs=None):
ax.set_ylim((0, ymax))

# add shaded bar for photostim
delay = (experiment.Period # TODO: use from period_starts
& 'period = "delay"').fetch1('period_start')
axs[1].axvspan(delay, delay + stim_dur, alpha=0.3, color='royalblue')
stim_time = period_starts[np.where(period_names == 'delay')[0][0]]
axs[1].axvspan(stim_time, stim_time + stim_dur, alpha=0.3, color='royalblue')

return fig

Expand All @@ -423,10 +455,6 @@ def plot_psth_photostim_effect(units, condition_name_kw=['both_alm'], axs=None):

hemi = _get_units_hemisphere(units)

period_starts = (experiment.Period
& 'period in ("sample", "delay", "response")').fetch(
'period_start')

# no photostim:
psth_n_l = psth.TrialCondition.get_cond_name_from_keywords(['_nostim', '_left'])[0]
psth_n_r = psth.TrialCondition.get_cond_name_from_keywords(['_nostim', '_right'])[0]
Expand All @@ -444,6 +472,9 @@ def plot_psth_photostim_effect(units, condition_name_kw=['both_alm'], axs=None):
psth_s_r = (psth.UnitPsth * psth.TrialCondition & units
& {'trial_condition_name': psth_s_r} & 'unit_psth is not NULL').fetch('unit_psth')

# get event start times: sample, delay, response
period_names, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

# get photostim duration
stim_trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(condition_name_kw + ['_stim'])[0]
stim_durs = np.unique((experiment.Photostim & experiment.PhotostimEvent
Expand Down Expand Up @@ -474,7 +505,7 @@ def plot_psth_photostim_effect(units, condition_name_kw=['both_alm'], axs=None):
ax.set_xlim([_plt_xmin, _plt_xmax])

# add shaded bar for photostim
stim_time = (experiment.Period & 'period = "delay"').fetch1('period_start')
stim_time = period_starts[np.where(period_names == 'delay')[0][0]]
axs[1].axvspan(stim_time, stim_time + stim_dur, alpha=0.3, color='royalblue')

return fig
Expand All @@ -484,7 +515,8 @@ def plot_coding_direction(units, time_period=None, axs=None):
_, proj_contra_trial, proj_ipsi_trial, time_stamps, _ = psth.compute_CD_projected_psth(
units.fetch('KEY'), time_period=time_period)

period_starts = (experiment.Period & 'period in ("sample", "delay", "response")').fetch('period_start')
# get event start times: sample, delay, response
period_names, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick')

fig = None
if axs is None:
Expand Down Expand Up @@ -515,7 +547,8 @@ def plot_paired_coding_direction(unit_g1, unit_g2, labels=None, time_period=None
_, proj_contra_trial_g2, proj_ipsi_trial_g2, time_stamps, unit_g2_hemi = psth.compute_CD_projected_psth(
unit_g2.fetch('KEY'), time_period=time_period)

period_starts = (experiment.Period & 'period in ("sample", "delay", "response")').fetch('period_start')
# get event start times: sample, delay, response
period_names, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], unit_g1, 'good_noearlylick')

if labels:
assert len(labels) == 2
Expand Down
Loading

0 comments on commit 5736578

Please sign in to comment.