q1_softmax_sol.py
"""Solution to the coding part for question (1) of CS224D.
"""
import numpy as np


def softmax_sol(x):
    """
    Compute the softmax function for each row of the input x.

    It is crucial that this function is optimized for speed because
    it will be used frequently in later code.
    You might find the numpy functions np.exp, np.sum, np.reshape, and
    np.max, as well as numpy broadcasting, useful for this task. (numpy
    broadcasting documentation:
    http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)

    You should also make sure that your code works for one-dimensional
    inputs (treat the vector as a row); you might find this helpful
    for later problems.

    You must implement the optimization in problem 1(a) of the
    written assignment!
    """
    ### YOUR CODE HERE
    if len(x.shape) > 1:
        # Matrix input: subtract each row's max before exponentiating so
        # np.exp cannot overflow; softmax is invariant to this shift.
        tmp = np.max(x, axis=1)
        # Build a new array here so the caller's x is not modified.
        x = x - tmp.reshape((x.shape[0], 1))
        x = np.exp(x)
        tmp = np.sum(x, axis=1)
        x /= tmp.reshape((x.shape[0], 1))
    else:
        # Vector input: treat it as a single row.
        tmp = np.max(x)
        x = x - tmp
        x = np.exp(x)
        tmp = np.sum(x)
        x /= tmp
    ### END YOUR CODE
    return x
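
# The max-subtraction step above relies on softmax(x) == softmax(x + c) for
# any constant c. The sketch below checks that property numerically. It is
# only an illustration: softmax_rows is a hypothetical, self-contained variant
# that uses keepdims instead of reshape, and is not part of the original
# solution.

```python
import numpy as np


def softmax_rows(x):
    """Hypothetical keepdims-based variant; not part of the original solution."""
    x = np.asarray(x, dtype=float)
    # axis=-1 handles both a 1-D vector and the rows of a 2-D matrix.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)


# Adding a constant to every entry of a row leaves the softmax unchanged.
small = softmax_rows(np.array([1.0, 2.0]))
large = softmax_rows(np.array([1001.0, 1002.0]))
assert np.allclose(small, large)

# Each row of a matrix is normalized independently.
probs = softmax_rows(np.array([[1001.0, 1002.0], [3.0, 4.0]]))
assert np.allclose(probs.sum(axis=1), 1.0)
assert np.allclose(probs[0], probs[1])
```

# Without the shift, np.exp(1001.0) overflows to inf and the division yields
# nan, which is why the row maximum is subtracted first.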