Skip to content

Commit

Permalink
Merge pull request #7 from ixcat/group-psth
Browse files Browse the repository at this point in the history
UnitGroupPsth implementation - initial implementation; data correction / updates to follow
  • Loading branch information
Thinh Nguyen authored Jun 13, 2019
2 parents eb7bf95 + 54e6c9b commit bc5a563
Show file tree
Hide file tree
Showing 9 changed files with 633 additions and 300 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
FROM datajoint/jupyter:latest

RUN pip install datajoint==0.12.dev4

RUN apt update && apt -y install mysql-client-5.7 netcat

RUN pip install globus_sdk
Expand Down
238 changes: 238 additions & 0 deletions notebook/group_psth.ipynb

Large diffs are not rendered by default.

94 changes: 80 additions & 14 deletions notebook/unit_psth_low_level_plot.ipynb

Large diffs are not rendered by default.

117 changes: 70 additions & 47 deletions notebook/unit_psth_quick_plot.ipynb

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion pipeline/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# -*- mode: python -*-

from .util import show_source

from .unit_psth import unit_psth
from .unit_psth import unit_psth_ll

from .group_psth import group_psth
from .group_psth import group_psth_ll

__all__ = [show_source,
unit_psth, unit_psth_ll,
group_psth, group_psth_ll]
53 changes: 53 additions & 0 deletions pipeline/plot/group_psth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@

import numpy as np

import matplotlib.pyplot as plt

from pipeline import psth


def movmean(data, nsamp=5): # TODO: moveout
''' moving average over n samples '''
ret = np.cumsum(data, dtype=float)
ret[nsamp:] = ret[nsamp:] - ret[:-nsamp]
return ret[nsamp - 1:] / nsamp


def group_psth_ll(psth_a, psth_b, invert=False):
plt_xmin, plt_xmax = -3, 3

assert len(psth_a) == len(psth_b)
nunits = len(psth_a)
aspect = 2 / nunits
extent = [plt_xmin, plt_xmax, 0, nunits]

a_data = np.array([r[0] for r in psth_a['unit_psth']])
b_data = np.array([r[0] for r in psth_b['unit_psth']])

# scale per-unit PSTHS's
a_data = np.array([movmean(i * (1 / np.max(i))) for i in a_data])
b_data = np.array([movmean(i * (1 / np.max(i))) for i in b_data])

if invert:
result = (a_data - b_data) * -1
else:
result = a_data - b_data

ax = plt.subplot(111)

# ax.set_axis_off()
ax.set_xlim([plt_xmin, plt_xmax])
ax.axvline(0, 0, 1, ls='--', color='k')
ax.axvline(-1.2, 0, 1, ls='--', color='k')
ax.axvline(-2.4, 0, 1, ls='--', color='k')

plt.imshow(result, cmap=plt.cm.bwr, aspect=aspect, extent=extent)


def group_psth(group_condition_key):

# XXX: currently raises NotImplementedError;
# see group_psth_rework.ipynb for latest status
unit_psths = psth.UnitGroupPsth.get(group_condition_key)

group_psth_ll(unit_psths[:]['unit_psth'])
88 changes: 48 additions & 40 deletions pipeline/plot/unit_psth.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,94 @@
# -*- mode: python -*-


import numpy as np
import scipy as sp
import datajoint as dj

import matplotlib.pyplot as plt

from pipeline import experiment
from pipeline import ephys
from pipeline import psth


def unit_psth_ll(ipsi_hit, contra_hit, ipsi_err, contra_err):
max_trial_off = 500
binSize=0.04
plt_xmin=-3
plt_xmax=3
plt_ymin=0
plt_ymax=None # dynamic per unit

plt_ymax = np.max([contra_hit['psth'][0], ipsi_hit['psth'][0], contra_err['psth'][0], ipsi_err['psth'][0]])


plt_xmin = -3
plt_xmax = 3
plt_ymin = 0
plt_ymax = None # dynamic per unit

plt_ymax = np.max([contra_hit['psth'][0],
ipsi_hit['psth'][0],
contra_err['psth'][0],
ipsi_err['psth'][0]])

plt.figure()

# raster plot
ax=plt.subplot(411)
plt.plot(contra_hit['raster'][0], contra_hit['raster'][1] + max_trial_off, 'b.', markersize=1)
ax = plt.subplot(411)
plt.plot(contra_hit['raster'][0], contra_hit['raster'][1] + max_trial_off,
'b.', markersize=1)
plt.plot(ipsi_hit['raster'][0], ipsi_hit['raster'][1], 'r.', markersize=1)
ax.set_axis_off()
ax.set_xlim([plt_xmin, plt_xmax])
ax.axvline(0,0,1, ls='--')
ax.axvline(-1.2,0,1, ls='--')
ax.axvline(-2.4,0,1, ls='--')
ax.axvline(0, 0, 1, ls='--')
ax.axvline(-1.2, 0, 1, ls='--')
ax.axvline(-2.4, 0, 1, ls='--')

# histogram of hits
ax = plt.subplot(412)
plt.plot(contra_hit['psth'][1][1:], contra_hit['psth'][0], 'b')
plt.plot(ipsi_hit['psth'][1][1:], ipsi_hit['psth'][0], 'r')

plt.ylabel('spikes/s')
ax.spines["top"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xlim([plt_xmin, plt_xmax])
ax.set_ylim([plt_ymin, plt_ymax])
ax.set_xticklabels([])
ax.axvline(0,0,1, ls='--')
ax.axvline(-1.2,0,1, ls='--')
ax.axvline(-2.4,0,1, ls='--')
ax.axvline(0, 0, 1, ls='--')
ax.axvline(-1.2, 0, 1, ls='--')
ax.axvline(-2.4, 0, 1, ls='--')
plt.title('Correct trials')

# histogram of errors
ax = plt.subplot(413)
plt.plot(contra_err['psth'][1][1:], contra_err['psth'][0], 'b')
plt.plot(ipsi_err['psth'][1][1:], ipsi_err['psth'][0], 'r')
ax.spines["top"].set_visible(False)

ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xlim([plt_xmin, plt_xmax])
ax.set_ylim([plt_ymin, plt_ymax])
ax.axvline(0,0,1, ls='--')
ax.axvline(-1.2,0,1, ls='--')
ax.axvline(-2.4,0,1, ls='--')
ax.axvline(0, 0, 1, ls='--')
ax.axvline(-1.2, 0, 1, ls='--')
ax.axvline(-2.4, 0, 1, ls='--')

plt.title('Error trials')
plt.xlabel('Time to go cue (s)')
plt.show()


def unit_psth(unit_key):

ipsi_hit_cond_key = (psth.Condition() & {'condition_desc': 'audio delay ipsi hit'}).fetch1('KEY')
contra_hit_cond_key = (psth.Condition() & {'condition_desc': 'audio delay contra hit'}).fetch1('KEY')

ipsi_err_cond_key = (psth.Condition() & {'condition_desc': 'audio delay ipsi error'}).fetch1('KEY')
contra_err_cond_key = (psth.Condition() & {'condition_desc': 'audio delay contra error'}).fetch1('KEY')

ipsi_hit_cond_key = (
psth.Condition() & {'condition_desc': 'audio delay ipsi hit'}
).fetch1('KEY')

contra_hit_cond_key = (
psth.Condition() & {'condition_desc': 'audio delay contra hit'}
).fetch1('KEY')

ipsi_err_cond_key = (
psth.Condition() & {'condition_desc': 'audio delay ipsi error'}
).fetch1('KEY')

contra_err_cond_key = (
psth.Condition() & {'condition_desc': 'audio delay contra error'}
).fetch1('KEY')

ipsi_hit_unit_psth = psth.UnitPsth.get(ipsi_hit_cond_key, unit_key)
contra_hit_unit_psth = psth.UnitPsth.get(contra_hit_cond_key, unit_key)

ipsi_err_unit_psth = psth.UnitPsth.get(ipsi_hit_cond_key, unit_key)
ipsi_err_unit_psth = psth.UnitPsth.get(ipsi_err_cond_key, unit_key)
contra_err_unit_psth = psth.UnitPsth.get(contra_err_cond_key, unit_key)

unit_psth_ll(ipsi_hit_unit_psth, contra_hit_unit_psth, ipsi_err_unit_psth, contra_err_unit_psth)

unit_psth_ll(ipsi_hit_unit_psth, contra_hit_unit_psth,
ipsi_err_unit_psth, contra_err_unit_psth)
Loading

0 comments on commit bc5a563

Please sign in to comment.