From 23a54f60fe9a39910e7f65e763debb7c72f60ffb Mon Sep 17 00:00:00 2001 From: ago109 Date: Wed, 24 Jul 2024 16:00:28 -0400 Subject: [PATCH] integrated if-cell, cleaned up lif and inits --- ngclearn/components/__init__.py | 1 + ngclearn/components/jaxComponent.py | 1 - ngclearn/components/neurons/__init__.py | 1 + ngclearn/components/neurons/spiking/IFCell.py | 306 ++++++++++++++++++ .../components/neurons/spiking/LIFCell.py | 47 ++- .../components/neurons/spiking/__init__.py | 1 + 6 files changed, 345 insertions(+), 12 deletions(-) create mode 100755 ngclearn/components/neurons/spiking/IFCell.py diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py index d9534871c..005fbacc5 100644 --- a/ngclearn/components/__init__.py +++ b/ngclearn/components/__init__.py @@ -6,6 +6,7 @@ from .neurons.graded.rewardErrorCell import RewardErrorCell ## point to standard spiking cell component types from .neurons.spiking.sLIFCell import SLIFCell +from .neurons.spiking.IFCell import IFCell from .neurons.spiking.LIFCell import LIFCell from .neurons.spiking.WTASCell import WTASCell from .neurons.spiking.quadLIFCell import QuadLIFCell diff --git a/ngclearn/components/jaxComponent.py b/ngclearn/components/jaxComponent.py index 8286c6c02..f07309fe5 100755 --- a/ngclearn/components/jaxComponent.py +++ b/ngclearn/components/jaxComponent.py @@ -21,4 +21,3 @@ def __init__(self, name, key=None, directory=None, **kwargs): self.directory = directory self.key = Compartment( random.PRNGKey(time.time_ns()) if key is None else key) - diff --git a/ngclearn/components/neurons/__init__.py b/ngclearn/components/neurons/__init__.py index 900a58cec..42a4a971c 100644 --- a/ngclearn/components/neurons/__init__.py +++ b/ngclearn/components/neurons/__init__.py @@ -5,6 +5,7 @@ from .graded.rewardErrorCell import RewardErrorCell ## point to standard spiking cell component types from .spiking.sLIFCell import SLIFCell +from .spiking.IFCell import IFCell from .spiking.LIFCell import LIFCell from .spiking.WTASCell import WTASCell from .spiking.quadLIFCell import QuadLIFCell diff --git a/ngclearn/components/neurons/spiking/IFCell.py b/ngclearn/components/neurons/spiking/IFCell.py new file mode 100755 index 000000000..2c9acf529 --- /dev/null +++ b/ngclearn/components/neurons/spiking/IFCell.py @@ -0,0 +1,306 @@ +from jax import numpy as jnp, random, jit, nn +from ngclearn.utils import tensorstats +from ngcsimlib.deprecators import deprecate_args +from ngclearn import resolver, Component, Compartment +from ngclearn.components.jaxComponent import JaxComponent +from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ + step_euler, step_rk2 +from ngclearn.utils.surrogate_fx import (arctan_estimator, + triangular_estimator, + straight_through_estimator) + +@jit +def _update_times(t, s, tols): + """ + Updates time-of-last-spike (tols) variable. + + Args: + t: current time (a scalar/int value) + + s: binary spike vector + + tols: current time-of-last-spike variable + + Returns: + updated tols variable + """ + _tols = (1. - s) * tols + (s * t) + return _tols + +@jit +def _dfv_internal(j, v, rfr, tau_m, refract_T): ## raw voltage dynamics + mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask + ## update voltage / membrane potential + dv_dt = (j * mask) ## integration only involves electrical current + dv_dt = dv_dt * (1./tau_m) + return dv_dt + +def _dfv(t, v, params): ## voltage dynamics wrapper + j, rfr, tau_m, refract_T = params + dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T) + return dv_dt + +def _run_cell(dt, j, v, v_thr, rfr, tau_m, v_rest, v_reset, refract_T, integType=0): + ### Runs integrator (or integrate-and-fire; IF) neuronal dynamics + ## update voltage / membrane potential + v_params = (j, rfr, tau_m, refract_T) + if integType == 1: + _, _v = step_rk2(0., v, _dfv, dt, v_params) + else: + _, _v = step_euler(0., v, _dfv, dt, v_params) + ## obtain action potentials/spikes + s = (_v > v_thr).astype(jnp.float32) + ## update refractory variables + _rfr = (rfr + dt) * (1. - s) + ## perform hyper-polarization of neuronal cells + _v = _v * (1. - s) + s * v_reset + return _v, s, _rfr + +class IFCell(JaxComponent): ## integrate-and-fire cell + """ + A spiking cell based on integrate-and-fire (IF) neuronal dynamics. + + The specific differential equation that characterizes this cell + is (for adjusting v, given current j, over time) is: + + | tau_m * dv/dt = (v_rest - v) + j * R + | where R is the membrane resistance and v_rest is the resting potential + | also, if a spike occurs, v is set to v_reset + + | --- Cell Input Compartments: --- + | j - electrical current input (takes in external signals) + | --- Cell State Compartments: --- + | v - membrane potential/voltage state + | rfr - (relative) refractory variable state + | key - JAX PRNG key + | --- Cell Output Compartments: --- + | s - emitted binary spikes/action potentials + | s_raw - raw spike signals before post-processing (only if one_spike = True, else s_raw = s) + | tols - time-of-last-spike + + Args: + name: the string name of this cell + + n_units: number of cellular entities (neural population size) + + tau_m: membrane time constant + + resist_m: membrane resistance value (default: 1) + + thr: base value for adaptive thresholds that govern short-term + plasticity (in milliVolts, or mV; default: -52. mV) + + v_rest: membrane resting potential (in mV; default: -65 mV) + + v_reset: membrane reset potential (in mV) -- upon occurrence of a spike, + a neuronal cell's membrane potential will be set to this value; + (default: -60 mV) + + refract_time: relative refractory period time (ms; default: 0 ms) + + integration_type: type of integration to use for this cell's dynamics; + current supported forms include "euler" (Euler/RK-1 integration) + and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler") + + :Note: setting the integration type to the midpoint method will + increase the accuray of the estimate of the cell's evolution + at an increase in computational cost (and simulation time) + + surrgoate_type: type of surrogate function to use for approximating a + partial derivative of this cell's spikes w.r.t. its voltage/current + (default: "straight_through") + + :Note: surrogate options available include: "straight_through" + (straight-through estimator), "triangular" (triangular estimator), + and "arctan" (arc-tangent estimator) + + lower_clamp_voltage: if True, this will ensure voltage never is below + the value of `v_rest` (default: True) + """ + + @deprecate_args(thr_jitter=None) + def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., + v_reset=-60., refract_time=0., integration_type="euler", + surrgoate_type="straight_through", lower_clamp_voltage=True, + **kwargs): + super().__init__(name, **kwargs) + + ## Integration properties + self.integrationType = integration_type + self.intgFlag = get_integrator_code(self.integrationType) + + ## membrane parameter setup (affects ODE integration) + self.tau_m = tau_m ## membrane time constant + self.resist_m = resist_m ## resistance value + + self.v_rest = v_rest #-65. # mV + self.v_reset = v_reset # -60. # -65. # mV (milli-volts) + ## basic asserts to prevent neuronal dynamics breaking... + assert self.resist_m > 0. + self.refract_T = refract_time #5. # 2. ## refractory period # ms + self.thr = thr ## (fixed) base value for threshold #-52 # -72. # mV + self.lower_clamp_voltage = lower_clamp_voltage + + ## Layer Size Setup + self.batch_size = 1 + self.n_units = n_units + + ## set up surrogate function for spike emission + if surrgoate_type == "arctan": + self.spike_fx, self.d_spike_fx = arctan_estimator() + elif surrgoate_type == "triangular": + self.spike_fx, self.d_spike_fx = triangular_estimator() + else: ## default: straight_through + self.spike_fx, self.d_spike_fx = straight_through_estimator() + + + ## Compartment setup + restVals = jnp.zeros((self.batch_size, self.n_units)) + self.j = Compartment(restVals, display_name="Current", units="mA") + self.v = Compartment(restVals + self.v_rest, + display_name="Voltage", units="mV") + self.s = Compartment(restVals, display_name="Spikes") + self.rfr = Compartment(restVals + self.refract_T, + display_name="Refractory Time Period", units="ms") + self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", + units="ms") ## time-of-last-spike + self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value") + + @staticmethod + def _advance_state(t, dt, tau_m, resist_m, v_rest, v_reset, refract_T, + thr, lower_clamp_voltage, intgFlag, d_spike_fx, key, + j, v, rfr, tols): + ## run one integration step for neuronal dynamics + j = j * resist_m + v, s, rfr = _run_cell(dt, j, v, thr, rfr, tau_m, v_rest, v_reset, + refract_T, intgFlag) + surrogate = d_spike_fx(v, thr) + ## update tols + tols = _update_times(t, s, tols) + if lower_clamp_voltage: ## ensure voltage never < v_rest + v = jnp.maximum(v, v_rest) + return v, s, rfr, tols, key, surrogate + + @resolver(_advance_state) + def advance_state(self, v, s, rfr, tols, key, surrogate): + self.v.set(v) + self.s.set(s) + self.rfr.set(rfr) + self.tols.set(tols) + self.key.set(key) + self.surrogate.set(surrogate) + + @staticmethod + def _reset(batch_size, n_units, v_rest, refract_T): + restVals = jnp.zeros((batch_size, n_units)) + j = restVals #+ 0 + v = restVals + v_rest + s = restVals #+ 0 + rfr = restVals + refract_T + tols = restVals #+ 0 + surrogate = restVals + 1. + return j, v, s, rfr, tols, surrogate + + @resolver(_reset) + def reset(self, j, v, s, rfr, tols, surrogate): + self.j.set(j) + self.v.set(v) + self.s.set(s) + self.rfr.set(rfr) + self.tols.set(tols) + self.surrogate.set(surrogate) + + def save(self, directory, **kwargs): + ## do a protected save of constants, depending on whether they are floats or arrays + tau_m = (self.tau_m if isinstance(self.tau_m, float) + else jnp.ones([[self.tau_m]])) + thr = (self.thr if isinstance(self.thr, float) + else jnp.ones([[self.thr]])) + v_rest = (self.v_rest if isinstance(self.v_rest, float) + else jnp.ones([[self.v_rest]])) + v_reset = (self.v_reset if isinstance(self.v_reset, float) + else jnp.ones([[self.v_reset]])) + v_decay = (self.v_decay if isinstance(self.v_decay, float) + else jnp.ones([[self.v_decay]])) + resist_m = (self.resist_m if isinstance(self.resist_m, float) + else jnp.ones([[self.resist_m]])) + tau_theta = (self.tau_theta if isinstance(self.tau_theta, float) + else jnp.ones([[self.tau_theta]])) + theta_plus = (self.theta_plus if isinstance(self.theta_plus, float) + else jnp.ones([[self.theta_plus]])) + + file_name = directory + "/" + self.name + ".npz" + jnp.savez(file_name, + tau_m=tau_m, thr=thr, v_rest=v_rest, + v_reset=v_reset, v_decay=v_decay, + resist_m=resist_m, tau_theta=tau_theta, + theta_plus=theta_plus, + key=self.key.value) + + def load(self, directory, seeded=False, **kwargs): + file_name = directory + "/" + self.name + ".npz" + data = jnp.load(file_name) + ## constants loaded in + self.tau_m = data['tau_m'] + self.thr = data['thr'] + self.v_rest = data['v_rest'] + self.v_reset = data['v_reset'] + self.v_decay = data['v_decay'] + self.resist_m = data['resist_m'] + self.tau_theta = data['tau_theta'] + self.theta_plus = data['theta_plus'] + + if seeded: + self.key.set(data['key']) + + @classmethod + def help(cls): ## component help function + properties = { + "cell_type": "IFCell - evolves neurons according to integrate-" + "and-fire spiking dynamics." + } + compartment_props = { + "inputs": + {"j": "External input electrical current"}, + "states": + {"v": "Membrane potential/voltage at time t", + "rfr": "Current state of (relative) refractory variable", + "thr": "Current state of voltage threshold at time t", + "key": "JAX PRNG key"}, + "outputs": + {"s": "Emitted spikes/pulses at time t", + "tols": "Time-of-last-spike"}, + } + hyperparams = { + "n_units": "Number of neuronal cells to model in this layer", + "tau_m": "Cell membrane time constant", + "resist_m": "Membrane resistance value", + "thr": "Base voltage threshold value", + "v_rest": "Resting membrane potential value", + "v_reset": "Reset membrane potential value", + "refract_time": "Length of relative refractory period (ms)", + "integration_type": "Type of numerical integration to use for the cell dynamics", + "surrgoate_type": "Type of surrogate function to use approximate " + "derivative of spike w.r.t. voltage/current", + "lower_bound_clamp": "Should voltage be lower bounded to be never be below `v_rest`" + } + info = {cls.__name__: properties, + "compartments": compartment_props, + "dynamics": "tau_m * dv/dt = (v_rest - v) + j * resist_m", + "hyperparameters": hyperparams} + return info + + def __repr__(self): + comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + maxlen = max(len(c) for c in comps) + 5 + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" + for c in comps: + stats = tensorstats(getattr(self, c).value) + if stats is not None: + line = [f"{k}: {v}" for k, v in stats.items()] + line = ", ".join(line) + else: + line = "None" + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" + return lines + diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index 44c2474d5..ead51f020 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -116,12 +116,13 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell resist_m: membrane resistance value (Default: 1) thr: base value for adaptive thresholds that govern short-term - plasticity (in milliVolts, or mV) + plasticity (in milliVolts, or mV; default: -52. mV) - v_rest: membrane resting potential (in mV) + v_rest: membrane resting potential (in mV; default: -65 mV) v_reset: membrane reset potential (in mV) -- upon occurrence of a spike, - a neuronal cell's membrane potential will be set to this value + a neuronal cell's membrane potential will be set to this value; + (default: -60 mV) v_decay: decay factor applied to voltage leak (Default: 1.); setting this to 0 mV recovers pure integrate-and-fire (IF) dynamics @@ -131,7 +132,7 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell theta_plus: physical increment to be applied to any threshold value if a spike was emitted - refract_time: relative refractory period time (ms; Default: 1 ms) + refract_time: relative refractory period time (ms; Default: 5 ms) one_spike: if True, a single-spike constraint will be enforced for every time step of neuronal dynamics simulated, i.e., at most, only @@ -146,13 +147,26 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell :Note: setting the integration type to the midpoint method will increase the accuray of the estimate of the cell's evolution at an increase in computational cost (and simulation time) + + surrgoate_type: type of surrogate function to use for approximating a + partial derivative of this cell's spikes w.r.t. its voltage/current + (default: "straight_through") + + :Note: surrogate options available include: "straight_through" + (straight-through estimator), "triangular" (triangular estimator), + "arctan" (arc-tangent estimator), and "secant_lif" (the + LIF-specialized secant estimator) + + lower_clamp_voltage: if True, this will ensure voltage never is below + the value of `v_rest` (default: True) """ @deprecate_args(thr_jitter=None) def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., v_decay=1., tau_theta=1e7, theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler", - surrgoate_type="straight_through", **kwargs): + surrgoate_type="straight_through", lower_clamp_voltage=True, + **kwargs): super().__init__(name, **kwargs) ## Integration properties @@ -163,6 +177,7 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., self.tau_m = tau_m ## membrane time constant self.resist_m = resist_m ## resistance value self.one_spike = one_spike ## True => constrains system to simulate 1 spike per time step + self.lower_clamp_voltage = lower_clamp_voltage ## True ==> ensures voltage is never < v_rest self.v_rest = v_rest #-65. # mV self.v_reset = v_reset # -60. # -65. # mV (milli-volts) @@ -207,8 +222,8 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., @staticmethod def _advance_state(t, dt, tau_m, resist_m, v_rest, v_reset, v_decay, refract_T, - thr, tau_theta, theta_plus, one_spike, intgFlag, d_spike_fx, - key, j, v, s, rfr, thr_theta, tols): + thr, tau_theta, theta_plus, one_spike, lower_clamp_voltage, + intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols): skey = None ## this is an empty dkey if single_spike mode turned off if one_spike: key, skey = random.split(key, 2) @@ -223,7 +238,9 @@ def _advance_state(t, dt, tau_m, resist_m, v_rest, v_reset, v_decay, refract_T, thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus) ## update tols tols = _update_times(t, s, tols) - return jnp.maximum(v, v_rest), s, raw_spikes, rfr, thr_theta, tols, key, surrogate + if lower_clamp_voltage: ## ensure voltage never < v_rest + v = jnp.maximum(v, v_rest) + return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate @resolver(_advance_state) def advance_state(self, v, s, s_raw, rfr, thr_theta, tols, key, surrogate): @@ -317,6 +334,7 @@ def help(cls): ## component help function {"v": "Membrane potential/voltage at time t", "rfr": "Current state of (relative) refractory variable", "thr": "Current state of voltage threshold at time t", + "thr_theta": "Current state of homeostatic adaptive threshold at time t", "key": "JAX PRNG key"}, "outputs": {"s": "Emitted spikes/pulses at time t", @@ -331,10 +349,17 @@ def help(cls): ## component help function "v_reset": "Reset membrane potential value", "v_decay": "Voltage leak/decay factor", "tau_theta": "Threshold/homoestatic increment time constant", - "theta_plus": "Amount to increment threshold by upon occurrence of spike", + "theta_plus": "Amount to increment threshold by upon occurrence " + "of spike", "refract_time": "Length of relative refractory period (ms)", - "one_spike": "Should only one spike be sampled/allowed to emit at any given time step?", - "integration_type": "Type of numerical integration to use for the cell dynamics" + "one_spike": "Should only one spike be sampled/allowed to emit at " + "any given time step?", + "integration_type": "Type of numerical integration to use for the " + "cell dynamics", + "surrgoate_type": "Type of surrogate function to use approximate " + "derivative of spike w.r.t. voltage/current", + "lower_bound_clamp": "Should voltage be lower bounded to be never " + "be below `v_rest`" } info = {cls.__name__: properties, "compartments": compartment_props, diff --git a/ngclearn/components/neurons/spiking/__init__.py b/ngclearn/components/neurons/spiking/__init__.py index cd9aa1811..2934eda9d 100644 --- a/ngclearn/components/neurons/spiking/__init__.py +++ b/ngclearn/components/neurons/spiking/__init__.py @@ -1,6 +1,7 @@ ## point to standard spiking cell component types from .sLIFCell import SLIFCell from .LIFCell import LIFCell +from .IFCell import IFCell from .WTASCell import WTASCell from .quadLIFCell import QuadLIFCell from .adExCell import AdExCell