Skip to content

Commit

Permalink
cleaned up raf-cell
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Aug 9, 2024
1 parent ee50f33 commit 611e5b3
Showing 1 changed file with 32 additions and 30 deletions.
62 changes: 32 additions & 30 deletions ngclearn/components/neurons/spiking/RAFCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ngclearn import resolver, Component, Compartment
from ngclearn.components.jaxComponent import JaxComponent
from ngclearn.utils import tensorstats
from ngcsimlib.deprecators import deprecate_args
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2

Expand Down Expand Up @@ -37,7 +38,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper

@jit
def _dfw_internal(j, v, w, tau_w, omega, b): ## raw angular driver dynamics
# dx/dt = b x − omega y + I
# dx/dt = b x − omega y + I; I is scaled injected electrical current
dw_dt = w * b - v * omega + j
dw_dt = dw_dt * (1./tau_w)
return dw_dt
Expand All @@ -61,9 +62,10 @@ class RAFCell(JaxComponent):
The specific pair of differential equations that characterize this cell
are (for adjusting v and w, given current j, over time):
| tau_m * dv/dt = omega * w + v * b
| tau_w * dw/dt = w * b - v * omega + j
| tau_v * dv/dt = omega * w + v * b
| where omega is angular frequency (Hz) and b is exponential dampening factor
| Note: injected current j should generally be scaled by tau_w/dt
| --- Cell Input Compartments: ---
| j - electrical current input (takes in external signals)
Expand All @@ -84,27 +86,27 @@ class RAFCell(JaxComponent):
n_units: number of cellular entities (neural population size)
tau_m: membrane time constant (Default: 15 ms)
tau_v: membrane/voltage time constant (Default: 1 ms)
resist_m: membrane resistance (Default: 1 mega-Ohm)
tau_w: angular driver variable time constant (Default: 400 ms)
tau_w: angular driver variable time constant (Default: 1 ms)
thr: voltage/membrane threshold (to obtain action potentials in terms
of binary spikes) (Default: 5 mV)
of binary spikes) (Default: 1 mV)
omega: angular frequency (Default: 10)
b: oscillation dampening factor (Default: -1)
v_reset: membrane reset potential condition (Default: 0 mV)
v_reset: membrane potential reset condition (Default: 1 mV)
w_reset: reset condition for angular driver (Default: 0)
w_reset: reset condition for angular current driver (Default: 0)
v0: membrane potential initial condition (Default: 0 mV)
v0: membrane potential initial condition (Default: 1 mV)
w0: angular driver initial condition (Default: 0)
resist_v: membrane resistance (Default: 1 mega-Ohm)
integration_type: type of integration to use for this cell's dynamics;
current supported forms include "euler" (Euler/RK-1 integration)
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
Expand All @@ -114,25 +116,24 @@ class RAFCell(JaxComponent):
at an increase in computational cost (and simulation time)
"""

# Define Functions
def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
thr=5., omega=10., b=-1., v_reset=0., w_reset=0.,
v0=0., w0=0., integration_type="euler", batch_size=1, **kwargs):
#v_rest=-72., v_reset=-75., w_reset=0., thr=5., v0=-70., w0=0.,
@deprecate_args(resist_m="resist_v", tau_m="tau_v")
def __init__(self, name, n_units, tau_v=1., tau_w=1., thr=1., omega=10.,
b=-1., v_reset=1., w_reset=0., v0=0., w0=0., resist_v=1.,
integration_type="euler", batch_size=1, **kwargs):
#v_rest=-72., v_reset=-75., w_reset=0., thr=5., v0=-70., w0=0., tau_w=400., thr=5., omega=10., b=-1.
super().__init__(name, **kwargs)

## Integration properties
self.integrationType = integration_type
self.intgFlag = get_integrator_code(self.integrationType)

## Cell properties
self.tau_m = tau_m
self.resist_m = resist_m
self.tau_v = tau_v
self.resist_v = resist_v
self.tau_w = tau_w
self.omega = omega ## angular frequency
self.b = b ## dampening factor
## note: the smaller b is, the faster the oscillation dampens to resting state values
#self.v_rest = v_rest
self.v_reset = v_reset
self.w_reset = w_reset
self.v0 = v0
Expand All @@ -153,24 +154,25 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
units="ms") ## time-of-last-spike

@staticmethod
def _advance_state(t, dt, tau_m, resist_m, tau_w, thr, omega, b,
def _advance_state(t, dt, tau_v, resist_v, tau_w, thr, omega, b,
v_reset, w_reset, intgFlag, j, v, w, tols):
## continue with centered dynamics
j_ = j * resist_m
j_ = j * resist_v
if intgFlag == 1: ## RK-2/midpoint
w_params = (j_, v, tau_w, omega, b)
_, _w = step_rk2(0., w, _dfw, dt, w_params)
v_params = (j_, w, tau_m, omega, b)
v_params = (j_, w, tau_v, omega, b)
_, _v = step_rk2(0., v, _dfv, dt, v_params)
else: # integType == 0 (default -- Euler)
w_params = (j_, v, tau_w, omega, b)
_, _w = step_euler(0., w, _dfw, dt, w_params)
v_params = (j_, w, tau_m, omega, b)
v_params = (j_, w, tau_v, omega, b)
_, _v = step_euler(0., v, _dfv, dt, v_params)
s = _emit_spike(_v, thr)
## hyperpolarize/reset/snap variables
v = _v * (1. - s) + s * v_reset
w = _w * (1. - s) + s * w_reset
v = _v * (1. - s) + s * v_reset

tols = _update_times(t, s, tols)
return j, v, w, s, tols

Expand All @@ -183,11 +185,11 @@ def advance_state(self, j, v, w, s, tols):
self.tols.set(tols)

@staticmethod
def _reset(batch_size, n_units, v0, w0):
def _reset(batch_size, n_units, v_reset, w_reset):
restVals = jnp.zeros((batch_size, n_units))
j = restVals # None
v = restVals + v0
w = restVals + w0
v = restVals + v_reset
w = restVals + w_reset
s = restVals #+ 0
tols = restVals #+ 0
return j, v, w, s, tols
Expand All @@ -212,28 +214,28 @@ def help(cls): ## component help function
"key": "JAX PRNG key"},
"states":
{"v": "Membrane potential/voltage at time t",
"w": "Recovery variable at time t"},
"w": "Angular current driver variable at time t"},
"outputs":
{"s": "Emitted spikes/pulses at time t",
"tols": "Time-of-last-spike"},
}
hyperparams = {
"n_units": "Number of neuronal cells to model in this layer",
"batch_size": "Batch size dimension of this component",
"tau_m": "Cell membrane time constant",
"resist_m": "Membrane resistance value",
"tau_v": "Cell membrane time constant",
"tau_w": "Recovery variable time constant",
"v_reset": "Reset membrane potential value",
"w_reset": "Reset angular driver value",
"b": "Exponential dampening factor applied to oscillations",
"omega": "Angular frequency of neuronal progress per second (radians)",
"v0": "Initial condition for membrane potential/voltage",
"w0": "Initial condition for membrane angular driver variable",
"resist_v": "Membrane resistance value",
"integration_type": "Type of numerical integration to use for the cell dynamics"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
"dynamics": "tau_m * dv/dt = omega * w + v * b; "
"dynamics": "tau_v * dv/dt = omega * w + v * b; "
"tau_w * dw/dt = w * b - v * omega + j",
"hyperparameters": hyperparams}
return info
Expand Down

0 comments on commit 611e5b3

Please sign in to comment.