Add l1 decay term to update calculation (#84)
* Update hebbianSynapse.py

* Update main

Update main at the end.

* Update hebbianSynapse.py

Add the regularization argument; w_decay is now deprecated.

* Update hebbianSynapse.py

Add elastic_net support.

* Update hebbianSynapse.py

* Update hebbianSynapse.py
Faezehabibi authored Dec 9, 2024
1 parent 2295ba5 commit eeb057a
Showing 1 changed file with 47 additions and 17 deletions.

ngclearn/components/synapses/hebbian/hebbianSynapse.py
@@ -4,9 +4,11 @@
 from ngclearn import resolver, Component, Compartment
 from ngclearn.components.synapses import DenseSynapse
 from ngclearn.utils import tensorstats
+from ngcsimlib.deprecators import deprecate_args

-@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8])
-def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1., w_decay=0.,
+@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
+def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1.,
+                 prior_type=None, prior_lmbda=0.,
                  pre_wght=1., post_wght=1.):
     """
     Compute a tensor of adjustments to be applied to a synaptic value matrix.
@@ -25,7 +27,9 @@ def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1., w_decay
         signVal: multiplicative factor to modulate final update by (good for
             flipping the signs of a computed synaptic change matrix)
-        w_decay: synaptic decay factor to apply to this update
+        prior_type: name of the prior to apply (Default: None)
+        prior_lmbda: prior strength parameter lambda (Default: 0.)
         pre_wght: pre-synaptic weighting term (Default: 1.)
@@ -38,10 +42,21 @@ def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1., w_decay
     _post = post * post_wght
     dW = jnp.matmul(_pre.T, _post)
     db = jnp.sum(_post, axis=0, keepdims=True)
+    dW_reg = 0.
+
     if w_bound > 0.:
         dW = dW * (w_bound - jnp.abs(W))
-    if w_decay > 0.:
-        dW = dW - W * w_decay
+
+    if prior_type == "l2" or prior_type == "ridge":
+        dW_reg = W
+    if prior_type == "l1" or prior_type == "lasso":
+        dW_reg = jnp.sign(W)
+    if prior_type == "l1l2" or prior_type == "elastic_net":
+        l1_ratio = prior_lmbda[1]
+        prior_lmbda = prior_lmbda[0]
+        dW_reg = jnp.sign(W) * l1_ratio + W * (1 - l1_ratio) / 2
+
+    dW = dW + prior_lmbda * dW_reg
     return dW * signVal, db * signVal

 @partial(jit, static_argnums=[1,2])
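The three branches are the (sub)gradients of the standard penalties: "l2"/"ridge" contributes W, "l1"/"lasso" contributes sign(W), and "l1l2"/"elastic_net" blends the two, with prior_lmbda arriving packed as (lmbda, l1_ratio). Because prior_type and prior_lmbda sit in static_argnums, jit specializes on the string/tuple values instead of tracing them. A minimal standalone sketch of the same branching (function name, shapes, and values are illustrative, not ngclearn's API):

from functools import partial
from jax import jit
import jax.numpy as jnp

@partial(jit, static_argnums=[1, 2])
def prior_grad(W, prior_type, prior_lmbda):
    ## branch is resolved at trace time; prior_type/prior_lmbda are static
    if prior_type in ("l2", "ridge"):
        return prior_lmbda * W            ## gradient of (lmbda/2) * ||W||^2
    if prior_type in ("l1", "lasso"):
        return prior_lmbda * jnp.sign(W)  ## subgradient of lmbda * ||W||_1
    if prior_type in ("l1l2", "elastic_net"):
        lmbda, l1_ratio = prior_lmbda     ## packed as (lmbda, l1_ratio)
        return lmbda * (jnp.sign(W) * l1_ratio + W * (1 - l1_ratio) / 2)
    return jnp.zeros_like(W)

W = jnp.array([[0.5, -0.2], [0.0, 1.0]])
print(prior_grad(W, "elastic_net", (0.01, 0.5)))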
@@ -68,6 +83,7 @@ def _enforce_constraints(W, w_bound, is_nonnegative=True):
         _W = jnp.clip(_W, -w_bound, w_bound)
     return _W

+
 class HebbianSynapse(DenseSynapse):
     """
     A synaptic cable that adjusts its efficacies via a two-factor Hebbian
@@ -107,9 +123,17 @@ class HebbianSynapse(DenseSynapse):
         is_nonnegative: enforce that synaptic efficacies are always non-negative
             after each synaptic update (if False, no constraint will be applied)
-        w_decay: degree to which (L2) synaptic weight decay is applied to the
-            computed Hebbian adjustment (Default: 0); note that decay is not
-            applied to any configured biases
+        prior: a kernel that places a prior over this synaptic cable's values;
+            typically a tuple whose first element is a string naming the prior
+            to use and whose second element is a floating-point number giving
+            the prior parameter lambda (Default: (None, 0.)); currently
+            supported names are "l1"/"lasso", "l2"/"ridge", and "l1l2"/"elastic_net".
+            Usage guide:
+            prior = ('l1', 0.01) or prior = ('lasso', lmbda)
+            prior = ('l2', 0.01) or prior = ('ridge', lmbda)
+            prior = ('l1l2', (0.01, 0.01)) or prior = ('elastic_net', (lmbda, l1_ratio))
         sign_value: multiplicative factor to apply to final synaptic update before
             it is applied to synapses; this is useful if gradient descent style
@@ -137,18 +161,24 @@ class HebbianSynapse(DenseSynapse):
     """

     # Define Functions
+    @deprecate_args(_rebind=False, w_decay='prior')
     def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
-                 w_bound=1., is_nonnegative=False, w_decay=0., sign_value=1.,
+                 w_bound=1., is_nonnegative=False, prior=(None, 0.), w_decay=0., sign_value=1.,
                  optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
                  resist_scale=1., batch_size=1, **kwargs):
         super().__init__(name, shape, weight_init, bias_init, resist_scale,
                          p_conn, batch_size=batch_size, **kwargs)

+        if w_decay > 0.:
+            prior = ('l2', w_decay)
+
+        prior_type, prior_lmbda = prior
         ## synaptic plasticity properties and characteristics
         self.shape = shape
         self.Rscale = resist_scale
+        self.prior_type = prior_type
+        self.prior_lmbda = prior_lmbda
         self.w_bound = w_bound
-        self.w_decay = w_decay ## synaptic decay
         self.pre_wght = pre_wght
         self.post_wght = post_wght
         self.eta = eta
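The @deprecate_args(_rebind=False, w_decay='prior') decorator together with the `if w_decay > 0.` guard gives legacy call sites a migration path: a w_decay value is folded into an equivalent L2 prior before unpacking. A sketch of the equivalence (eta and lambda values illustrative):

## Legacy style, now deprecated: L2 decay applied inside the update
Wab = HebbianSynapse("Wab", (2, 3), eta=0.0004, w_decay=0.01)
## New equivalent: explicit ridge/L2 prior with lambda = 0.01
Wab = HebbianSynapse("Wab", (2, 3), eta=0.0004, prior=("l2", 0.01))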
@@ -172,21 +202,21 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
             if bias_init else [self.weights.value]))

     @staticmethod
-    def _compute_update(w_bound, is_nonnegative, sign_value, w_decay, pre_wght,
+    def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
                         post_wght, pre, post, weights):
         ## calculate synaptic update values
         dW, db = _calc_update(
             pre, post, weights, w_bound, is_nonnegative=is_nonnegative,
-            signVal=sign_value, w_decay=w_decay, pre_wght=pre_wght,
+            signVal=sign_value, prior_type=prior_type, prior_lmbda=prior_lmbda, pre_wght=pre_wght,
             post_wght=post_wght)
         return dW, db

     @staticmethod
-    def _evolve(opt, w_bound, is_nonnegative, sign_value, w_decay, pre_wght,
+    def _evolve(opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
                 post_wght, bias_init, pre, post, weights, biases, opt_params):
         ## calculate synaptic update values
         dWeights, dBiases = HebbianSynapse._compute_update(
-            w_bound, is_nonnegative, sign_value, w_decay, pre_wght, post_wght,
+            w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght,
             pre, post, weights
         )
         ## conduct a step of optimization - get newly evolved synaptic weight value matrix
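Only plumbing changes here: prior_type and prior_lmbda now thread from _evolve through _compute_update into the jit-compiled kernel. As a toy check of the kernel itself (assuming module-level access to _calc_update; static arguments passed positionally to respect static_argnums):

import jax.numpy as jnp

pre = jnp.ones((1, 2))   ## one batch row of pre-synaptic activity
post = jnp.ones((1, 3))  ## one batch row of post-synaptic activity
W = jnp.full((2, 3), 0.5)
## args: pre, post, W, w_bound, is_nonnegative, signVal, prior_type, prior_lmbda
dW, db = _calc_update(pre, post, W, 1., True, 1., "l1", 0.01)
print(dW.shape, db.shape)  ## (2, 3) (1, 3)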
@@ -264,13 +294,13 @@ def help(cls): ## component help function
             "pre_wght": "Pre-synaptic weighting coefficient (q_pre)",
             "post_wght": "Post-synaptic weighting coefficient (q_post)",
             "w_bound": "Soft synaptic bound applied to synapses post-update",
-            "w_decay": "Synaptic decay term",
+            "prior": "Name and lambda value of the prior used during synaptic updating",
             "optim_type": "Choice of optimizer to adjust synaptic weights"
         }
         info = {cls.__name__: properties,
                 "compartments": compartment_props,
                 "dynamics": "outputs = [(W * Rscale) * inputs] + b ;"
-                            "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - W_{ij} * w_decay",
+                            "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - g(W_{ij}) * prior_lmbda",
                 "hyperparameters": hyperparams}
         return info
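In equation form, the dynamics string corresponds to the following (our rendering; g is the prior's (sub)gradient and r the elastic-net mixing ratio):

\frac{dW_{ij}}{dt} = \eta\,(z_j q_{\text{pre}})(z_i q_{\text{post}}) - \lambda\, g(W_{ij}), \qquad
g(W) =
\begin{cases}
W & \text{l2 / ridge} \\
\operatorname{sign}(W) & \text{l1 / lasso} \\
r\,\operatorname{sign}(W) + \frac{1-r}{2}\,W & \text{l1l2 / elastic\_net}
\end{cases}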

@@ -292,5 +322,5 @@ def __repr__(self):
     from ngcsimlib.context import Context
     with Context("Bar") as bar:
         Wab = HebbianSynapse("Wab", (2, 3), 0.0004, optim_type='adam',
-                             sign_value=-1.0, bias_init=("constant", 0., 0.))
+                             sign_value=-1.0, prior=("l1l2", 0.001))
     print(Wab)
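Mirroring the docstring's usage guide, a sketch of constructing the component with each supported prior (context name and values illustrative). Note that the __main__ example just above passes a scalar lambda with "l1l2", while _calc_update indexes prior_lmbda[1] in that branch, so the (lmbda, l1_ratio) tuple form is the safe one:

from ngcsimlib.context import Context

with Context("Demo") as demo:
    W_l1 = HebbianSynapse("W_l1", (2, 3), 0.0004, prior=("lasso", 0.01))
    W_l2 = HebbianSynapse("W_l2", (2, 3), 0.0004, prior=("ridge", 0.01))
    ## elastic net packs lambda and the L1 mixing ratio as one tuple
    W_en = HebbianSynapse("W_en", (2, 3), 0.0004,
                          prior=("elastic_net", (0.01, 0.5)))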
