@@ -72,18 +72,20 @@ def create_function(fun_name, args=None):
     Activation function creation routine.

     Args:
-        fun_name: string name of activation function to produce
-            (Currently supports: "tanh", "relu", "lrelu", "identity")
+        fun_name: string name of activation function to produce;
+            Currently supports: "tanh", "bkwta" (binary K-winners-take-all), "sigmoid", "relu", "lrelu", "relu6",
+            "elu", "silu", "gelu", "softplus", "softmax" (derivative not supported), "unit_threshold", "heaviside",
+            "identity"

     Returns:
         function fx, first derivative of function (w.r.t. input) dfx
     """
-    fx = None
-    dfx = None
+    fx = None ## the function
+    dfx = None ## the first derivative of function w.r.t. its input
     if fun_name == "tanh":
         fx = tanh
         dfx = d_tanh
-    elif "kwta" in fun_name:
+    elif "bkwta" in fun_name:
         fx = bkwta
         dfx = bkwta #d_identity
     elif fun_name == "sigmoid":
@@ -98,6 +100,15 @@ def create_function(fun_name, args=None):
     elif fun_name == "relu6":
         fx = relu6
         dfx = d_relu6
+    elif fun_name == "elu":
+        fx = elu
+        dfx = d_elu
+    elif fun_name == "silu":
+        fx = silu
+        dfx = d_silu
+    elif fun_name == "gelu":
+        fx = gelu
+        dfx = d_gelu
     elif fun_name == "softplus":
         fx = softplus
         dfx = d_softplus
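
For reference, a minimal usage sketch of the factory patched above (not part of the diff). It assumes the patched file is importable as ngclearn.utils.model_utils (an assumed path, not stated in the diff) and that create_function returns exactly the (fx, dfx) pair described in its docstring:

    from jax import numpy as jnp
    from ngclearn.utils.model_utils import create_function  ## assumed import path

    fx, dfx = create_function("gelu")  ## "gelu"/"d_gelu" support is added by this patch
    x = jnp.asarray([[-1., 0., 2.]])
    y = fx(x)       ## activation values
    dy_dx = dfx(x)  ## element-wise derivative w.r.t. the input x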
@@ -127,35 +138,35 @@ def bkwta(x, nWTA=5): #5 10 15 #K=50):
     return topK

 @partial(jit, static_argnums=[2, 3, 4])
-def normalize_matrix(M, wnorm, order=1, axis=0, scale=1.):
+def normalize_matrix(data, wnorm, order=1, axis=0, scale=1.):
     """
     Normalizes the values in matrix to have a particular norm across each vector span.

     Args:
-        M: (2D) matrix to normalize
+        data: (2D) data matrix to normalize

-        wnorm: target norm for each
+        wnorm: target norm for each row/column of data matrix

         order: order of norm to use in normalization (Default: 1);
             note that `ord=1` results in the L1-norm, `ord=2` results in the L2-norm

         axis: 0 (apply to column vectors), 1 (apply to row vectors)

-        scale: step modifier to produce the projected matrix
+        scale: step modifier to produce the projected matrix (Unused)

     Returns:
         a normalized value matrix
     """
     if order == 2: ## denominator is L2 norm
-        wOrdSum = jnp.maximum(jnp.sqrt(jnp.sum(jnp.square(M), axis=axis, keepdims=True)), 1e-8)
+        wOrdSum = jnp.maximum(jnp.sqrt(jnp.sum(jnp.square(data), axis=axis, keepdims=True)), 1e-8)
     else: ## denominator is L1 norm
-        wOrdSum = jnp.maximum(jnp.sum(jnp.abs(M), axis=axis, keepdims=True), 1e-8)
+        wOrdSum = jnp.maximum(jnp.sum(jnp.abs(data), axis=axis, keepdims=True), 1e-8)
     m = (wOrdSum == 0.).astype(dtype=jnp.float32)
     wOrdSum = wOrdSum * (1. - m) + m #wAbsSum[wAbsSum == 0.] = 1.
-    _M = M * (wnorm / wOrdSum)
-    #dM = ((wnorm/wOrdSum) - 1.) * M
-    #_M = M + dM * scale
-    return _M
+    _data = data * (wnorm / wOrdSum)
+    #d_data = ((wnorm/wOrdSum) - 1.) * data
+    #_data = data + d_data * scale
+    return _data

 @jit
 def clamp_min(x, min_val):
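
A hedged sketch of calling the renamed normalize_matrix (again assuming the ngclearn.utils.model_utils import path). The order, axis, and scale arguments are passed positionally here because the function is wrapped with partial(jit, static_argnums=[2, 3, 4]):

    from jax import numpy as jnp
    from ngclearn.utils.model_utils import normalize_matrix  ## assumed import path

    W = jnp.asarray([[3., 0.],
                     [4., 2.]])
    ## rescale each column of W to unit L2 norm: wnorm=1., order=2 (L2), axis=0 (columns)
    W_unit = normalize_matrix(W, 1., 2, 0)
    print(jnp.sum(jnp.square(W_unit), axis=0))  ## expected: approximately [1. 1.]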
@@ -529,7 +540,7 @@ def elu(x, alpha=1.):
     return x * mask + ((jnp.exp(x) - 1) * alpha) * (1. - mask)

 @jit
-def elu(x, alpha=1.):
+def d_elu(x, alpha=1.):
     mask = (x >= 0.)
     return mask + (1. - mask) * (jnp.exp(x) * alpha)
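
The last hunk looks like a duplicate-definition fix: the derivative was previously defined under the name elu, shadowing the activation itself. A small sanity check (import path assumed as above) comparing the hand-coded d_elu against the JAX autodiff gradient of elu:

    import jax
    from jax import numpy as jnp
    from ngclearn.utils.model_utils import elu, d_elu  ## assumed import path

    x = jnp.linspace(-3., 3., 7)
    ## element-wise autodiff gradient of elu, compared against the analytic d_elu
    auto_grad = jax.vmap(jax.grad(lambda v: elu(v)))(x)
    print(jnp.allclose(auto_grad, d_elu(x)))  ## expected: True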