update log_softmax
konas122 committed Feb 27, 2024
1 parent 13ce1ff commit c35fd59
Showing 3 changed files with 29 additions and 15 deletions.
40 changes: 27 additions & 13 deletions dazero/functions.py
@@ -71,6 +71,26 @@ def relu(x):
     return ReLU()(x)


+class LeakyReLU(Function):
+    def __init__(self, slope):
+        self.slope = slope
+
+    def forward(self, x):
+        y = x.copy()
+        y[x <= 0] *= self.slope
+        return y
+
+    def backward(self, gy):
+        x, = self.inputs
+        mask = (x.data > 0).astype(gy.dtype)
+        mask[mask <= 0] = self.slope
+        gx = gy * mask
+        return gx
+
+def leaky_relu(x, slope=0.2):
+    return LeakyReLU(slope)(x)
+
+
 def softmax_simple(x, axis=1):
     x = as_variable(x)
     y = exp(x)
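For reference, here is a standalone NumPy sketch of the forward/backward rule the new `LeakyReLU` Function encodes (illustrative only, outside dazero's `Function`/`Variable` machinery; the helper names are made up for the example):

```python
import numpy as np

def leaky_relu_forward(x, slope=0.2):
    # y = x where x > 0, slope * x elsewhere
    y = x.copy()
    y[x <= 0] *= slope
    return y

def leaky_relu_backward(x, gy, slope=0.2):
    # dy/dx = 1 where x > 0, slope elsewhere
    mask = (x > 0).astype(gy.dtype)
    mask[mask <= 0] = slope
    return gy * mask

x = np.array([-2.0, -0.5, 0.0, 1.5])
print(leaky_relu_forward(x))                    # [-0.4 -0.1  0.   1.5]
print(leaky_relu_backward(x, np.ones_like(x)))  # [0.2 0.2 0.2 1. ]
```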
@@ -105,19 +125,13 @@ def __init__(self, axis=1):
         self.axis = axis

     def forward(self, x):
-        xp = cuda.get_array_module(x)
-        y1 = x - x.max(axis=self.axis, keepdims=True)
-        y2 = xp.exp(y1)
-        y2 = y2.sum(axis=self.axis, keepdims=True)
-        y2 = xp.log(y2)
-        y = y1 - y2
+        log_z, _ = utils.logsumexp(x, self.axis)
+        y = x - log_z
         return y

     def backward(self, gy):
-        x = self.inputs[0]
-        xp = cuda.get_array_module(gy)
-        gx = Variable(xp.ones_like(x)) - softmax(x, self.axis)
-        gx *= gy
+        y = self.outputs[0]()
+        gx = gy - exp(y) * gy.sum(axis=self.axis, keepdims=True)
         return gx

 def log_softmax(x, axis=1):
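A quick standalone check of the two identities the rewrite relies on: `log_softmax(x) = x - logsumexp(x)` for the forward pass, and `gx = gy - exp(y) * gy.sum(axis, keepdims=True)` for the backward pass (NumPy-only sketch with made-up helper names; dazero's `utils.logsumexp` now also returns the array module, hence the `log_z, _ = ...` unpacking above):

```python
import numpy as np

def log_softmax_np(x, axis=1):
    # Stable log-softmax: x - logsumexp(x), with the max subtracted first
    m = x.max(axis=axis, keepdims=True)
    log_z = m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))
    return x - log_z

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5))
y = log_softmax_np(x)

# Agrees with the naive log(softmax(x))
assert np.allclose(y, np.log(np.exp(x) / np.exp(x).sum(axis=1, keepdims=True)))

# Backward rule used above, checked against a forward-difference estimate
gy = rng.standard_normal((2, 5))
gx = gy - np.exp(y) * gy.sum(axis=1, keepdims=True)

eps, (i, j) = 1e-6, (0, 2)
x_eps = x.copy()
x_eps[i, j] += eps
numeric = ((log_softmax_np(x_eps) - y) * gy).sum() / eps
assert np.isclose(gx[i, j], numeric, atol=1e-4)
```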
@@ -229,10 +243,10 @@ def softmax_cross_entropy_simple(x, t):
 class SoftmaxCrossEntropy(Function):
     def forward(self, x, t):
         N = x.shape[0]
-        log_z = utils.logsumexp(x, axis=1)
+        log_z, xp = utils.logsumexp(x, axis=1)
         log_p = x - log_z
-        log_p = log_p[np.arange(N), t.ravel()]
-        y = -log_p.sum() / np.float32(N)
+        log_p = log_p[xp.arange(N), t.ravel()]
+        y = -log_p.sum() / xp.float32(N)
         return y

     def backward(self, gy):
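With `utils.logsumexp` now returning the array module it used, the loss is indexed and averaged with `xp` (NumPy or CuPy) rather than hard-coded `np`. The quantity computed is the batch-averaged cross entropy, i.e. the mean over n of `logsumexp(x[n]) - x[n, t[n]]`; a NumPy-only sketch with illustrative names:

```python
import numpy as np

def softmax_cross_entropy_np(x, t):
    # x: (N, C) logits, t: (N,) integer class labels
    N = x.shape[0]
    m = x.max(axis=1, keepdims=True)
    log_z = m + np.log(np.exp(x - m).sum(axis=1, keepdims=True))
    log_p = x - log_z                       # log-softmax of the logits
    return -log_p[np.arange(N), t].sum() / np.float32(N)

x = np.array([[2.0, 1.0, 0.1],
              [0.5, 2.5, 0.3]])
t = np.array([0, 1])
print(softmax_cross_entropy_np(x, t))  # ~0.32, average negative log-likelihood
```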
2 changes: 1 addition & 1 deletion dazero/transformers.py
@@ -83,7 +83,7 @@ def __init__(self, dim, heads, mask, ffn_hidden_mult=4, dropout=0.0):
             F.ReLU(),
             Linear(ffn_hidden_mult * dim, dim)
         )

     def forward(self, x):
         attended = self.attention(x)
         x = self.norm1(attended + x)
2 changes: 1 addition & 1 deletion dazero/utils.py
@@ -209,7 +209,7 @@ def logsumexp(x, axis=1):
     s = y.sum(axis=axis, keepdims=True)
     xp.log(s, out=s)
     m += s
-    return m
+    return m, xp


 def max_backward_shape(x, axis):
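For context (the hunk above is truncated), `logsumexp` implements the standard stable reduction `m + log(sum(exp(x - m)))`, and it now returns the array module (`xp`, i.e. NumPy or CuPy) alongside the result so callers like `LogSoftmax` and `SoftmaxCrossEntropy` can stay device-agnostic. A minimal NumPy-only sketch of the same trick (illustrative name, no `cuda.get_array_module`):

```python
import numpy as np

def logsumexp_np(x, axis=1):
    # Subtract the row max before exponentiating to avoid overflow
    m = x.max(axis=axis, keepdims=True)
    s = np.exp(x - m).sum(axis=axis, keepdims=True)
    return m + np.log(s)

x = np.array([[1000.0, 1001.0, 1002.0]])
print(logsumexp_np(x))          # ~[[1002.4076]], computed without overflow
print(np.log(np.exp(x).sum()))  # inf: the naive form overflows
```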
