Commit 23e8c84

Author: Alexander Ororbia (committed)
Commit message: cleaned up model_utils
1 parent 1fbbf93 commit 23e8c84

File tree

1 file changed (+27, -16)

ngclearn/utils/model_utils.py

Lines changed: 27 additions & 16 deletions
```diff
@@ -72,18 +72,20 @@ def create_function(fun_name, args=None):
     Activation function creation routine.
 
     Args:
-        fun_name: string name of activation function to produce
-            (Currently supports: "tanh", "relu", "lrelu", "identity")
+        fun_name: string name of activation function to produce;
+            Currently supports: "tanh", "bkwta" (binary K-winners-take-all), "sigmoid", "relu", "lrelu", "relu6",
+            "elu", "silu", "gelu", "softplus", "softmax" (derivative not supported), "unit_threshold", "heaviside",
+            "identity"
 
     Returns:
         function fx, first derivative of function (w.r.t. input) dfx
     """
-    fx = None
-    dfx = None
+    fx = None ## the function
+    dfx = None ## the first derivative of function w.r.t. its input
     if fun_name == "tanh":
         fx = tanh
         dfx = d_tanh
-    elif "kwta" in fun_name:
+    elif "bkwta" in fun_name:
         fx = bkwta
         dfx = bkwta #d_identity
     elif fun_name == "sigmoid":
```
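For context, `bkwta` is the binary K-winners-take-all nonlinearity now named explicitly in the docstring and in the dispatch above. Below is a minimal sketch of the idea, an illustration under the standard binary K-WTA formulation rather than ngc-learn's exact implementation (`bkwta_sketch` is a hypothetical name):

```python
import jax.numpy as jnp

def bkwta_sketch(x, nWTA=5):
    ## hypothetical illustration of binary K-winners-take-all, not ngc-learn's code:
    ## per row, emit 1 at the nWTA largest entries and 0 everywhere else
    vals = jnp.sort(x, axis=1)                        ## ascending sort of each row
    thresh = jnp.expand_dims(vals[:, -nWTA], axis=1)  ## nWTA-th largest value per row
    return (x >= thresh).astype(jnp.float32)          ## binary winner mask

x = jnp.asarray([[0.1, 0.9, 0.3, 0.7, 0.5, 0.2]])
print(bkwta_sketch(x, nWTA=2))  ## [[0. 1. 0. 1. 0. 0.]]
```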
```diff
@@ -98,6 +100,15 @@ def create_function(fun_name, args=None):
     elif fun_name == "relu6":
         fx = relu6
         dfx = d_relu6
+    elif fun_name == "elu":
+        fx = elu
+        dfx = d_elu
+    elif fun_name == "silu":
+        fx = silu
+        dfx = d_silu
+    elif fun_name == "gelu":
+        fx = gelu
+        dfx = d_gelu
     elif fun_name == "softplus":
         fx = softplus
         dfx = d_softplus
```
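A short usage sketch of `create_function` as extended by this hunk, assuming the two-value `(fx, dfx)` return described in the docstring; the input values are illustrative only:

```python
import jax.numpy as jnp
from ngclearn.utils.model_utils import create_function

fx, dfx = create_function("elu")  ## activation fx and its first derivative dfx
x = jnp.asarray([[-1.0, 0.0, 2.0]])
print(fx(x))   ## ELU applied elementwise to x
print(dfx(x))  ## d ELU(x) / dx evaluated elementwise
```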
```diff
@@ -127,35 +138,35 @@ def bkwta(x, nWTA=5): #5 10 15 #K=50):
     return topK
 
 @partial(jit, static_argnums=[2, 3, 4])
-def normalize_matrix(M, wnorm, order=1, axis=0, scale=1.):
+def normalize_matrix(data, wnorm, order=1, axis=0, scale=1.):
     """
     Normalizes the values in matrix to have a particular norm across each vector span.
 
     Args:
-        M: (2D) matrix to normalize
+        data: (2D) data matrix to normalize
 
-        wnorm: target norm for each
+        wnorm: target norm for each row/column of data matrix
 
         order: order of norm to use in normalization (Default: 1);
            note that `ord=1` results in the L1-norm, `ord=2` results in the L2-norm
 
         axis: 0 (apply to column vectors), 1 (apply to row vectors)
 
-        scale: step modifier to produce the projected matrix
+        scale: step modifier to produce the projected matrix (Unused)
 
     Returns:
         a normalized value matrix
     """
     if order == 2: ## denominator is L2 norm
-        wOrdSum = jnp.maximum(jnp.sqrt(jnp.sum(jnp.square(M), axis=axis, keepdims=True)), 1e-8)
+        wOrdSum = jnp.maximum(jnp.sqrt(jnp.sum(jnp.square(data), axis=axis, keepdims=True)), 1e-8)
     else: ## denominator is L1 norm
-        wOrdSum = jnp.maximum(jnp.sum(jnp.abs(M), axis=axis, keepdims=True), 1e-8)
+        wOrdSum = jnp.maximum(jnp.sum(jnp.abs(data), axis=axis, keepdims=True), 1e-8)
     m = (wOrdSum == 0.).astype(dtype=jnp.float32)
     wOrdSum = wOrdSum * (1. - m) + m #wAbsSum[wAbsSum == 0.] = 1.
-    _M = M * (wnorm/wOrdSum)
-    #dM = ((wnorm/wOrdSum) - 1.) * M
-    #_M = M + dM * scale
-    return _M
+    _data = data * (wnorm/wOrdSum)
+    #d_data = ((wnorm/wOrdSum) - 1.) * data
+    #_data = data + d_data * scale
+    return _data
 
 @jit
 def clamp_min(x, min_val):
```
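A usage sketch for the renamed routine: rescale each column of a toy matrix to unit L2 norm (`order=2`, `axis=0`). Since `order`, `axis`, and `scale` are marked static for `jit`, they are passed positionally here; the matrix values are illustrative only:

```python
import jax.numpy as jnp
from jax import random
from ngclearn.utils.model_utils import normalize_matrix

key = random.PRNGKey(0)
W = random.normal(key, (4, 3))          ## toy 4x3 matrix
W_unit = normalize_matrix(W, 1., 2, 0)  ## wnorm=1., order=2 (L2), axis=0 (columns)
print(jnp.linalg.norm(W_unit, axis=0))  ## each column norm should be ~1.
```

Note that because the denominator is clamped at 1e-8, the zero-norm mask `m` computed afterwards can never fire; it survives as a redundant guard.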
```diff
@@ -529,7 +540,7 @@ def elu(x, alpha=1.):
     return x * mask + ((jnp.exp(x) - 1) * alpha) * (1. - mask)
 
 @jit
-def elu(x, alpha=1.):
+def d_elu(x, alpha=1.):
     mask = (x >= 0.)
     return mask + (1. - mask) * (jnp.exp(x) * alpha)
```
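This last hunk fixes a shadowing bug: the derivative was previously also named `elu`, so the second definition silently overwrote the activation at import time. A quick sanity check (a sketch, not part of the commit) that the renamed `d_elu` agrees with JAX autodiff away from the kink at x = 0:

```python
import jax
import jax.numpy as jnp
from ngclearn.utils.model_utils import elu, d_elu

xs = jnp.asarray([-2.0, -0.5, 0.5, 2.0])  ## sample points off the kink at 0
auto = jax.vmap(jax.grad(elu))(xs)        ## autodiff derivative of elu
print(jnp.allclose(auto, d_elu(xs)))      ## expect: True
```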
