Skip to content

Commit

Permalink
integrated if-cell, cleaned up lif and inits
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 24, 2024
1 parent c19d15e commit 23a54f6
Show file tree
Hide file tree
Showing 6 changed files with 345 additions and 12 deletions.
1 change: 1 addition & 0 deletions ngclearn/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .neurons.graded.rewardErrorCell import RewardErrorCell
## point to standard spiking cell component types
from .neurons.spiking.sLIFCell import SLIFCell
from .neurons.spiking.IFCell import IFCell
from .neurons.spiking.LIFCell import LIFCell
from .neurons.spiking.WTASCell import WTASCell
from .neurons.spiking.quadLIFCell import QuadLIFCell
Expand Down
1 change: 0 additions & 1 deletion ngclearn/components/jaxComponent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,3 @@ def __init__(self, name, key=None, directory=None, **kwargs):
self.directory = directory
self.key = Compartment(
random.PRNGKey(time.time_ns()) if key is None else key)

1 change: 1 addition & 0 deletions ngclearn/components/neurons/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .graded.rewardErrorCell import RewardErrorCell
## point to standard spiking cell component types
from .spiking.sLIFCell import SLIFCell
from .spiking.IFCell import IFCell
from .spiking.LIFCell import LIFCell
from .spiking.WTASCell import WTASCell
from .spiking.quadLIFCell import QuadLIFCell
Expand Down
306 changes: 306 additions & 0 deletions ngclearn/components/neurons/spiking/IFCell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
from jax import numpy as jnp, random, jit, nn
from ngclearn.utils import tensorstats
from ngcsimlib.deprecators import deprecate_args
from ngclearn import resolver, Component, Compartment
from ngclearn.components.jaxComponent import JaxComponent
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
from ngclearn.utils.surrogate_fx import (arctan_estimator,
triangular_estimator,
straight_through_estimator)

@jit
def _update_times(t, s, tols):
"""
Updates time-of-last-spike (tols) variable.
Args:
t: current time (a scalar/int value)
s: binary spike vector
tols: current time-of-last-spike variable
Returns:
updated tols variable
"""
_tols = (1. - s) * tols + (s * t)
return _tols

@jit
def _dfv_internal(j, v, rfr, tau_m, refract_T): ## raw voltage dynamics
mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
## update voltage / membrane potential
dv_dt = (j * mask) ## integration only involves electrical current
dv_dt = dv_dt * (1./tau_m)
return dv_dt

def _dfv(t, v, params): ## voltage dynamics wrapper
j, rfr, tau_m, refract_T = params
dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T)
return dv_dt

def _run_cell(dt, j, v, v_thr, rfr, tau_m, v_rest, v_reset, refract_T, integType=0):
### Runs integrator (or integrate-and-fire; IF) neuronal dynamics
## update voltage / membrane potential
v_params = (j, rfr, tau_m, refract_T)
if integType == 1:
_, _v = step_rk2(0., v, _dfv, dt, v_params)
else:
_, _v = step_euler(0., v, _dfv, dt, v_params)
## obtain action potentials/spikes
s = (_v > v_thr).astype(jnp.float32)
## update refractory variables
_rfr = (rfr + dt) * (1. - s)
## perform hyper-polarization of neuronal cells
_v = _v * (1. - s) + s * v_reset
return _v, s, _rfr

class IFCell(JaxComponent): ## integrate-and-fire cell
"""
A spiking cell based on integrate-and-fire (IF) neuronal dynamics.
The specific differential equation that characterizes this cell
is (for adjusting v, given current j, over time) is:
| tau_m * dv/dt = (v_rest - v) + j * R
| where R is the membrane resistance and v_rest is the resting potential
| also, if a spike occurs, v is set to v_reset
| --- Cell Input Compartments: ---
| j - electrical current input (takes in external signals)
| --- Cell State Compartments: ---
| v - membrane potential/voltage state
| rfr - (relative) refractory variable state
| key - JAX PRNG key
| --- Cell Output Compartments: ---
| s - emitted binary spikes/action potentials
| s_raw - raw spike signals before post-processing (only if one_spike = True, else s_raw = s)
| tols - time-of-last-spike
Args:
name: the string name of this cell
n_units: number of cellular entities (neural population size)
tau_m: membrane time constant
resist_m: membrane resistance value (default: 1)
thr: base value for adaptive thresholds that govern short-term
plasticity (in milliVolts, or mV; default: -52. mV)
v_rest: membrane resting potential (in mV; default: -65 mV)
v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
a neuronal cell's membrane potential will be set to this value;
(default: -60 mV)
refract_time: relative refractory period time (ms; default: 0 ms)
integration_type: type of integration to use for this cell's dynamics;
current supported forms include "euler" (Euler/RK-1 integration)
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
:Note: setting the integration type to the midpoint method will
increase the accuray of the estimate of the cell's evolution
at an increase in computational cost (and simulation time)
surrgoate_type: type of surrogate function to use for approximating a
partial derivative of this cell's spikes w.r.t. its voltage/current
(default: "straight_through")
:Note: surrogate options available include: "straight_through"
(straight-through estimator), "triangular" (triangular estimator),
and "arctan" (arc-tangent estimator)
lower_clamp_voltage: if True, this will ensure voltage never is below
the value of `v_rest` (default: True)
"""

@deprecate_args(thr_jitter=None)
def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
v_reset=-60., refract_time=0., integration_type="euler",
surrgoate_type="straight_through", lower_clamp_voltage=True,
**kwargs):
super().__init__(name, **kwargs)

## Integration properties
self.integrationType = integration_type
self.intgFlag = get_integrator_code(self.integrationType)

## membrane parameter setup (affects ODE integration)
self.tau_m = tau_m ## membrane time constant
self.resist_m = resist_m ## resistance value

self.v_rest = v_rest #-65. # mV
self.v_reset = v_reset # -60. # -65. # mV (milli-volts)
## basic asserts to prevent neuronal dynamics breaking...
assert self.resist_m > 0.
self.refract_T = refract_time #5. # 2. ## refractory period # ms
self.thr = thr ## (fixed) base value for threshold #-52 # -72. # mV
self.lower_clamp_voltage = lower_clamp_voltage

## Layer Size Setup
self.batch_size = 1
self.n_units = n_units

## set up surrogate function for spike emission
if surrgoate_type == "arctan":
self.spike_fx, self.d_spike_fx = arctan_estimator()
elif surrgoate_type == "triangular":
self.spike_fx, self.d_spike_fx = triangular_estimator()
else: ## default: straight_through
self.spike_fx, self.d_spike_fx = straight_through_estimator()


## Compartment setup
restVals = jnp.zeros((self.batch_size, self.n_units))
self.j = Compartment(restVals, display_name="Current", units="mA")
self.v = Compartment(restVals + self.v_rest,
display_name="Voltage", units="mV")
self.s = Compartment(restVals, display_name="Spikes")
self.rfr = Compartment(restVals + self.refract_T,
display_name="Refractory Time Period", units="ms")
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
units="ms") ## time-of-last-spike
self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")

@staticmethod
def _advance_state(t, dt, tau_m, resist_m, v_rest, v_reset, refract_T,
thr, lower_clamp_voltage, intgFlag, d_spike_fx, key,
j, v, rfr, tols):
## run one integration step for neuronal dynamics
j = j * resist_m
v, s, rfr = _run_cell(dt, j, v, thr, rfr, tau_m, v_rest, v_reset,
refract_T, intgFlag)
surrogate = d_spike_fx(v, thr)
## update tols
tols = _update_times(t, s, tols)
if lower_clamp_voltage: ## ensure voltage never < v_rest
v = jnp.maximum(v, v_rest)
return v, s, rfr, tols, key, surrogate

@resolver(_advance_state)
def advance_state(self, v, s, rfr, tols, key, surrogate):
self.v.set(v)
self.s.set(s)
self.rfr.set(rfr)
self.tols.set(tols)
self.key.set(key)
self.surrogate.set(surrogate)

@staticmethod
def _reset(batch_size, n_units, v_rest, refract_T):
restVals = jnp.zeros((batch_size, n_units))
j = restVals #+ 0
v = restVals + v_rest
s = restVals #+ 0
rfr = restVals + refract_T
tols = restVals #+ 0
surrogate = restVals + 1.
return j, v, s, rfr, tols, surrogate

@resolver(_reset)
def reset(self, j, v, s, rfr, tols, surrogate):
self.j.set(j)
self.v.set(v)
self.s.set(s)
self.rfr.set(rfr)
self.tols.set(tols)
self.surrogate.set(surrogate)

def save(self, directory, **kwargs):
## do a protected save of constants, depending on whether they are floats or arrays
tau_m = (self.tau_m if isinstance(self.tau_m, float)
else jnp.ones([[self.tau_m]]))
thr = (self.thr if isinstance(self.thr, float)
else jnp.ones([[self.thr]]))
v_rest = (self.v_rest if isinstance(self.v_rest, float)
else jnp.ones([[self.v_rest]]))
v_reset = (self.v_reset if isinstance(self.v_reset, float)
else jnp.ones([[self.v_reset]]))
v_decay = (self.v_decay if isinstance(self.v_decay, float)
else jnp.ones([[self.v_decay]]))
resist_m = (self.resist_m if isinstance(self.resist_m, float)
else jnp.ones([[self.resist_m]]))
tau_theta = (self.tau_theta if isinstance(self.tau_theta, float)
else jnp.ones([[self.tau_theta]]))
theta_plus = (self.theta_plus if isinstance(self.theta_plus, float)
else jnp.ones([[self.theta_plus]]))

file_name = directory + "/" + self.name + ".npz"
jnp.savez(file_name,
tau_m=tau_m, thr=thr, v_rest=v_rest,
v_reset=v_reset, v_decay=v_decay,
resist_m=resist_m, tau_theta=tau_theta,
theta_plus=theta_plus,
key=self.key.value)

def load(self, directory, seeded=False, **kwargs):
file_name = directory + "/" + self.name + ".npz"
data = jnp.load(file_name)
## constants loaded in
self.tau_m = data['tau_m']
self.thr = data['thr']
self.v_rest = data['v_rest']
self.v_reset = data['v_reset']
self.v_decay = data['v_decay']
self.resist_m = data['resist_m']
self.tau_theta = data['tau_theta']
self.theta_plus = data['theta_plus']

if seeded:
self.key.set(data['key'])

@classmethod
def help(cls): ## component help function
properties = {
"cell_type": "IFCell - evolves neurons according to integrate-"
"and-fire spiking dynamics."
}
compartment_props = {
"inputs":
{"j": "External input electrical current"},
"states":
{"v": "Membrane potential/voltage at time t",
"rfr": "Current state of (relative) refractory variable",
"thr": "Current state of voltage threshold at time t",
"key": "JAX PRNG key"},
"outputs":
{"s": "Emitted spikes/pulses at time t",
"tols": "Time-of-last-spike"},
}
hyperparams = {
"n_units": "Number of neuronal cells to model in this layer",
"tau_m": "Cell membrane time constant",
"resist_m": "Membrane resistance value",
"thr": "Base voltage threshold value",
"v_rest": "Resting membrane potential value",
"v_reset": "Reset membrane potential value",
"refract_time": "Length of relative refractory period (ms)",
"integration_type": "Type of numerical integration to use for the cell dynamics",
"surrgoate_type": "Type of surrogate function to use approximate "
"derivative of spike w.r.t. voltage/current",
"lower_bound_clamp": "Should voltage be lower bounded to be never be below `v_rest`"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
"dynamics": "tau_m * dv/dt = (v_rest - v) + j * resist_m",
"hyperparameters": hyperparams}
return info

def __repr__(self):
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
maxlen = max(len(c) for c in comps) + 5
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
for c in comps:
stats = tensorstats(getattr(self, c).value)
if stats is not None:
line = [f"{k}: {v}" for k, v in stats.items()]
line = ", ".join(line)
else:
line = "None"
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
return lines

Loading

0 comments on commit 23a54f6

Please sign in to comment.