Skip to content

Commit da2f24e

Browse files
authored
Add attribute 'lr' (#90)
* Update elastic_net.py * add lr as attribute to lasso.py * add lr as attribute to ridge.py * refactor w_bound=0. for weights elastic_net.py deactivated w_bound for weights elastic_net.py * Update lasso.py * deactivated w_bound for weights ridge.py
1 parent 9a3ce0e commit da2f24e

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed

ngclearn/modules/regression/elastic_net.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,14 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l
7373
self.weight_fill = weight_fill
7474
self.threshold = threshold
7575
self.name = name
76+
self.lr = lr
7677
feature_dim = dict_dim
7778

7879
with Context(self.name) as self.circuit:
79-
self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=lr,
80+
self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=self.lr,
8081
sign_value=-1, weight_init=dist.constant(weight_fill),
81-
prior=('elastic_net', (lmbda, l1_ratio)), optim_type=optim_type, key=subkeys[0])
82+
prior=('elastic_net', (lmbda, l1_ratio)), w_bound=0.,
83+
optim_type=optim_type, key=subkeys[0])
8284
self.err = GaussianErrorCell("err", n_units=sys_dim)
8385

8486
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

ngclearn/modules/regression/lasso.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,14 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l
7676
self.weight_fill = weight_fill
7777
self.threshold = threshold
7878
self.name = name
79+
self.lr = lr
7980
feature_dim = dict_dim
8081

8182
with Context(self.name) as self.circuit:
82-
self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=lr,
83+
self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=self.lr,
8384
sign_value=-1, weight_init=dist.constant(weight_fill),
84-
prior=('lasso', lasso_lmbda),
85-
optim_type=optim_type, key=subkeys[0])
85+
prior=('lasso', lasso_lmbda), w_bound=0.,
86+
optim_type=optim_type, key=subkeys[0])
8687
self.err = GaussianErrorCell("err", n_units=sys_dim)
8788
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8889
self.W.batch_size = batch_size

ngclearn/modules/regression/ridge.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,14 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l
7171
self.weight_fill = weight_fill
7272
self.threshold = threshold
7373
self.name = name
74+
self.lr = lr
7475
feature_dim = dict_dim
7576

7677
with Context(self.name) as self.circuit:
77-
self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=lr,
78+
self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=self.lr,
7879
sign_value=-1, weight_init=dist.constant(weight_fill),
79-
prior=('ridge', ridge_lmbda), optim_type=optim_type, key=subkeys[0])
80+
prior=('ridge', ridge_lmbda), w_bound=0.,
81+
optim_type=optim_type, key=subkeys[0])
8082
self.err = GaussianErrorCell("err", n_units=sys_dim)
8183

8284
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

0 commit comments

Comments
 (0)