Commit 16de340

Merge pull request BVLC#3116 from ronghanghu/solver-refactor

Solver Refactor: Separate files and Change Solver's Type to String

2 parents: 46dac40 + 9563537

29 files changed: +1463 −1047 lines

docs/tutorial/solver.md

Lines changed: 14 additions & 14 deletions
@@ -8,12 +8,12 @@ The responsibilities of learning are divided between the Solver for overseeing t
 
 The Caffe solvers are:
 
-- Stochastic Gradient Descent (`SGD`),
-- AdaDelta (`ADADELTA`),
-- Adaptive Gradient (`ADAGRAD`),
-- Adam (`ADAM`),
-- Nesterov's Accelerated Gradient (`NESTEROV`) and
-- RMSprop (`RMSPROP`)
+- Stochastic Gradient Descent (`type: "SGD"`),
+- AdaDelta (`type: "AdaDelta"`),
+- Adaptive Gradient (`type: "AdaGrad"`),
+- Adam (`type: "Adam"`),
+- Nesterov's Accelerated Gradient (`type: "Nesterov"`) and
+- RMSprop (`type: "RMSProp"`)
 
 The solver
 
@@ -51,7 +51,7 @@ The parameter update $$\Delta W$$ is formed by the solver from the error gradien
 
 ### SGD
 
-**Stochastic gradient descent** (`solver_type: SGD`) updates the weights $$ W $$ by a linear combination of the negative gradient $$ \nabla L(W) $$ and the previous weight update $$ V_t $$.
+**Stochastic gradient descent** (`type: "SGD"`) updates the weights $$ W $$ by a linear combination of the negative gradient $$ \nabla L(W) $$ and the previous weight update $$ V_t $$.
 The **learning rate** $$ \alpha $$ is the weight of the negative gradient.
 The **momentum** $$ \mu $$ is the weight of the previous update.
 
@@ -113,7 +113,7 @@ If learning diverges (e.g., you start to see very large or `NaN` or `inf` loss v
 
 ### AdaDelta
 
-The **AdaDelta** (`solver_type: ADADELTA`) method (M. Zeiler [1]) is a "robust learning rate method". It is a gradient-based optimization method (like SGD). The update formulas are
+The **AdaDelta** (`type: "AdaDelta"`) method (M. Zeiler [1]) is a "robust learning rate method". It is a gradient-based optimization method (like SGD). The update formulas are
 
 $$
 \begin{align}
@@ -125,7 +125,7 @@ E[g^2]_t &= \delta{E[g^2]_{t-1} } + (1-\delta)g_{t}^2
 \end{align}
 $$
 
-and
+and
 
 $$
 (W_{t+1})_i =
@@ -139,7 +139,7 @@ $$
 
 ### AdaGrad
 
-The **adaptive gradient** (`solver_type: ADAGRAD`) method (Duchi et al. [1]) is a gradient-based optimization method (like SGD) that attempts to "find needles in haystacks in the form of very predictive but rarely seen features," in Duchi et al.'s words.
+The **adaptive gradient** (`type: "AdaGrad"`) method (Duchi et al. [1]) is a gradient-based optimization method (like SGD) that attempts to "find needles in haystacks in the form of very predictive but rarely seen features," in Duchi et al.'s words.
 Given the update information from all previous iterations $$ \left( \nabla L(W) \right)_{t'} $$ for $$ t' \in \{1, 2, ..., t\} $$,
 the update formulas proposed by [1] are as follows, specified for each component $$i$$ of the weights $$W$$:
 
@@ -159,7 +159,7 @@ Note that in practice, for weights $$ W \in \mathcal{R}^d $$, AdaGrad implementa
 
 ### Adam
 
-The **Adam** (`solver_type: ADAM`), proposed in Kingma et al. [1], is a gradient-based optimization method (like SGD). This includes an "adaptive moment estimation" ($$m_t, v_t$$) and can be regarded as a generalization of AdaGrad. The update formulas are
+The **Adam** (`type: "Adam"`), proposed in Kingma et al. [1], is a gradient-based optimization method (like SGD). This includes an "adaptive moment estimation" ($$m_t, v_t$$) and can be regarded as a generalization of AdaGrad. The update formulas are
 
 $$
 (m_t)_i = \beta_1 (m_{t-1})_i + (1-\beta_1)(\nabla L(W_t))_i,\\
@@ -181,7 +181,7 @@ Kingma et al. [1] proposed to use $$\beta_1 = 0.9, \beta_2 = 0.999, \varepsilon
 
 ### NAG
 
-**Nesterov's accelerated gradient** (`solver_type: NESTEROV`) was proposed by Nesterov [1] as an "optimal" method of convex optimization, achieving a convergence rate of $$ \mathcal{O}(1/t^2) $$ rather than the $$ \mathcal{O}(1/t) $$.
+**Nesterov's accelerated gradient** (`type: "Nesterov"`) was proposed by Nesterov [1] as an "optimal" method of convex optimization, achieving a convergence rate of $$ \mathcal{O}(1/t^2) $$ rather than the $$ \mathcal{O}(1/t) $$.
 Though the required assumptions to achieve the $$ \mathcal{O}(1/t^2) $$ convergence typically will not hold for deep networks trained with Caffe (e.g., due to non-smoothness and non-convexity), in practice NAG can be a very effective method for optimizing certain types of deep learning architectures, as demonstrated for deep MNIST autoencoders by Sutskever et al. [2].
 
 The weight update formulas look very similar to the SGD updates given above:
@@ -206,10 +206,10 @@ What distinguishes the method from SGD is the weight setting $$ W $$ on which we
 
 ### RMSprop
 
-The **RMSprop** (`solver_type: RMSPROP`), suggested by Tieleman in a Coursera course lecture, is a gradient-based optimization method (like SGD). The update formulas are
+The **RMSprop** (`type: "RMSProp"`), suggested by Tieleman in a Coursera course lecture, is a gradient-based optimization method (like SGD). The update formulas are
 
 $$
-(v_t)_i =
+(v_t)_i =
 \begin{cases}
 (v_{t-1})_i + \delta, &(\nabla L(W_t))_i(\nabla L(W_{t-1}))_i > 0\\
 (v_{t-1})_i \cdot (1-\delta), & \text{else}
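
As a reading aid for the SGD hunk above: the tutorial's update rule is $$ V_{t+1} = \mu V_t - \alpha \nabla L(W_t) $$ followed by $$ W_{t+1} = W_t + V_{t+1} $$. Below is a minimal standalone C++ sketch of that rule, purely illustrative; it is not Caffe's SGDSolver::ComputeUpdateValue, and the function and variable names are hypothetical.

#include <vector>

// Illustrative momentum SGD step from the tutorial:
//   V_{t+1} = mu * V_t - alpha * grad(L(W_t));  W_{t+1} = W_t + V_{t+1}.
void sgd_update(std::vector<float>& weights,
                const std::vector<float>& grad,
                std::vector<float>& history,   // V_t, one entry per weight
                float alpha,                   // learning rate
                float mu) {                    // momentum
  for (size_t i = 0; i < weights.size(); ++i) {
    history[i] = mu * history[i] - alpha * grad[i];  // V_{t+1}
    weights[i] += history[i];                        // W_{t+1}
  }
}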

examples/mnist/lenet_adadelta_solver.prototxt

Lines changed: 1 addition & 1 deletion
@@ -20,5 +20,5 @@ snapshot: 5000
 snapshot_prefix: "examples/mnist/lenet_adadelta"
 # solver mode: CPU or GPU
 solver_mode: GPU
-solver_type: ADADELTA
+type: "AdaDelta"
 delta: 1e-6

examples/mnist/lenet_solver_adam.prototxt

Lines changed: 1 addition & 1 deletion
@@ -22,5 +22,5 @@ max_iter: 10000
 snapshot: 5000
 snapshot_prefix: "examples/mnist/lenet"
 # solver mode: CPU or GPU
-solver_type: ADAM
+type: "Adam"
 solver_mode: GPU
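
Since this solver file now selects the method purely via `type: "Adam"`, here is a brief, self-contained C++ sketch of the standard Adam step described in the tutorial hunk above. It is illustrative only, not Caffe's AdamSolver; the function and variable names are hypothetical, and the defaults mentioned follow Kingma et al.

#include <cmath>
#include <vector>

// Standard Adam step: first/second moment estimates with bias correction.
// Typical hyperparameters (per Kingma et al.): beta1=0.9, beta2=0.999, eps=1e-8.
void adam_update(std::vector<float>& w, const std::vector<float>& grad,
                 std::vector<float>& m, std::vector<float>& v,
                 int t,  // iteration count, starting at 1
                 float alpha, float beta1, float beta2, float eps) {
  const float correction = std::sqrt(1.0f - std::pow(beta2, t))
                           / (1.0f - std::pow(beta1, t));
  for (size_t i = 0; i < w.size(); ++i) {
    m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];              // (m_t)_i
    v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];    // (v_t)_i
    w[i] -= alpha * correction * m[i] / (std::sqrt(v[i]) + eps); // (W_{t+1})_i
  }
}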

examples/mnist/lenet_solver_rmsprop.prototxt

Lines changed: 1 addition & 1 deletion
@@ -23,5 +23,5 @@ snapshot: 5000
 snapshot_prefix: "examples/mnist/lenet_rmsprop"
 # solver mode: CPU or GPU
 solver_mode: GPU
-solver_type: RMSPROP
+type: "RMSProp"
 rms_decay: 0.98
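
The `rms_decay: 0.98` setting kept below the change is the decay constant of the moving average of squared gradients that RMSProp divides by. Below is a hedged C++ sketch of that commonly used form, illustrative only; it is not Caffe's RMSPropSolver, and the function and variable names are hypothetical.

#include <cmath>
#include <vector>

// Common RMSProp step: keep a decayed moving average of squared gradients and
// scale each gradient by its inverse square root. rms_decay is the decay
// constant of that average; delta avoids division by zero.
void rmsprop_update(std::vector<float>& weights,
                    const std::vector<float>& grad,
                    std::vector<float>& mean_square,  // running E[g^2]
                    float lr, float rms_decay, float delta) {
  for (size_t i = 0; i < weights.size(); ++i) {
    mean_square[i] = rms_decay * mean_square[i]
                   + (1.0f - rms_decay) * grad[i] * grad[i];
    weights[i] -= lr * grad[i] / (std::sqrt(mean_square[i]) + delta);
  }
}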

examples/mnist/mnist_autoencoder_solver_adadelta.prototxt

Lines changed: 1 addition & 1 deletion
@@ -16,4 +16,4 @@ snapshot: 10000
 snapshot_prefix: "examples/mnist/mnist_autoencoder_adadelta_train"
 # solver mode: CPU or GPU
 solver_mode: GPU
-solver_type: ADADELTA
+type: "AdaDelta"

examples/mnist/mnist_autoencoder_solver_adagrad.prototxt

Lines changed: 1 addition & 1 deletion
@@ -14,4 +14,4 @@ snapshot: 10000
 snapshot_prefix: "examples/mnist/mnist_autoencoder_adagrad_train"
 # solver mode: CPU or GPU
 solver_mode: GPU
-solver_type: ADAGRAD
+type: "AdaGrad"

examples/mnist/mnist_autoencoder_solver_nesterov.prototxt

Lines changed: 1 addition & 1 deletion
@@ -17,4 +17,4 @@ snapshot_prefix: "examples/mnist/mnist_autoencoder_nesterov_train"
 momentum: 0.95
 # solver mode: CPU or GPU
 solver_mode: GPU
-solver_type: NESTEROV
+type: "Nesterov"

include/caffe/caffe.hpp

Lines changed: 2 additions & 0 deletions
@@ -13,8 +13,10 @@
 #include "caffe/parallel.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/solver.hpp"
+#include "caffe/solver_factory.hpp"
 #include "caffe/util/benchmark.hpp"
 #include "caffe/util/io.hpp"
+#include "caffe/util/upgrade_proto.hpp"
 #include "caffe/vision_layers.hpp"
 
 #endif  // CAFFE_CAFFE_HPP_
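
With solver_factory.hpp and util/upgrade_proto.hpp now exported through the umbrella header, a client can construct a solver from the string `type` field rather than the old enum. A minimal sketch follows, assuming the factory exposes `SolverRegistry<Dtype>::CreateSolver()` and that `ReadSolverParamsFromTextFileOrDie()` lives in upgrade_proto as this refactor suggests; the prototxt path is just an example.

#include "caffe/caffe.hpp"

// Sketch: load a SolverParameter and let the factory pick the implementation
// from its string `type` field ("SGD", "AdaDelta", "Adam", ...).
int main() {
  caffe::SolverParameter param;
  // Assumed helper from util/upgrade_proto.hpp (expected to upgrade old
  // `solver_type` enum fields to the new string form).
  caffe::ReadSolverParamsFromTextFileOrDie(
      "examples/mnist/lenet_solver.prototxt", &param);
  caffe::Solver<float>* solver =
      caffe::SolverRegistry<float>::CreateSolver(param);  // assumed factory API
  solver->Solve();
  delete solver;
  return 0;
}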

include/caffe/sgd_solvers.hpp

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
#ifndef CAFFE_SGD_SOLVERS_HPP_
#define CAFFE_SGD_SOLVERS_HPP_

#include <string>
#include <vector>

#include "caffe/solver.hpp"

namespace caffe {

/**
 * @brief Optimizes the parameters of a Net using
 *        stochastic gradient descent (SGD) with momentum.
 */
template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
 public:
  explicit SGDSolver(const SolverParameter& param)
      : Solver<Dtype>(param) { PreSolve(); }
  explicit SGDSolver(const string& param_file)
      : Solver<Dtype>(param_file) { PreSolve(); }
  virtual inline const char* type() const { return "SGD"; }

  const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

 protected:
  void PreSolve();
  Dtype GetLearningRate();
  virtual void ApplyUpdate();
  virtual void Normalize(int param_id);
  virtual void Regularize(int param_id);
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  virtual void ClipGradients();
  virtual void SnapshotSolverState(const string& model_filename);
  virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
  virtual void SnapshotSolverStateToHDF5(const string& model_filename);
  virtual void RestoreSolverStateFromHDF5(const string& state_file);
  virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
  // history maintains the historical momentum data.
  // update maintains update related data and is not needed in snapshots.
  // temp maintains other information that might be needed in computation
  //   of gradients/updates and is not needed in snapshots
  vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;

  DISABLE_COPY_AND_ASSIGN(SGDSolver);
};

template <typename Dtype>
class NesterovSolver : public SGDSolver<Dtype> {
 public:
  explicit NesterovSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) {}
  explicit NesterovSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) {}
  virtual inline const char* type() const { return "Nesterov"; }

 protected:
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  DISABLE_COPY_AND_ASSIGN(NesterovSolver);
};

template <typename Dtype>
class AdaGradSolver : public SGDSolver<Dtype> {
 public:
  explicit AdaGradSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
  explicit AdaGradSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
  virtual inline const char* type() const { return "AdaGrad"; }

 protected:
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  void constructor_sanity_check() {
    CHECK_EQ(0, this->param_.momentum())
        << "Momentum cannot be used with AdaGrad.";
  }

  DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};


template <typename Dtype>
class RMSPropSolver : public SGDSolver<Dtype> {
 public:
  explicit RMSPropSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
  explicit RMSPropSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
  virtual inline const char* type() const { return "RMSProp"; }

 protected:
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  void constructor_sanity_check() {
    CHECK_EQ(0, this->param_.momentum())
        << "Momentum cannot be used with RMSProp.";
    CHECK_GE(this->param_.rms_decay(), 0)
        << "rms_decay should lie between 0 and 1.";
    CHECK_LT(this->param_.rms_decay(), 1)
        << "rms_decay should lie between 0 and 1.";
  }

  DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
};

template <typename Dtype>
class AdaDeltaSolver : public SGDSolver<Dtype> {
 public:
  explicit AdaDeltaSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
  explicit AdaDeltaSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
  virtual inline const char* type() const { return "AdaDelta"; }

 protected:
  void AdaDeltaPreSolve();
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
};

/**
 * @brief AdamSolver, an algorithm for first-order gradient-based optimization
 *        of stochastic objective functions, based on adaptive estimates of
 *        lower-order moments. Described in [1].
 *
 * [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization."
 *     arXiv preprint arXiv:1412.6980v8 (2014).
 */
template <typename Dtype>
class AdamSolver : public SGDSolver<Dtype> {
 public:
  explicit AdamSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { AdamPreSolve();}
  explicit AdamSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
  virtual inline const char* type() const { return "Adam"; }

 protected:
  void AdamPreSolve();
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  DISABLE_COPY_AND_ASSIGN(AdamSolver);
};

}  // namespace caffe

#endif  // CAFFE_SGD_SOLVERS_HPP_
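
The header above is what downstream code now subclasses to add a solver: reuse SGDSolver's bookkeeping and override `type()` and `ComputeUpdateValue()`. Below is a hypothetical sketch under the assumption that `REGISTER_SOLVER_CLASS` comes from the new solver_factory.hpp and `INSTANTIATE_CLASS` from common.hpp; the `PlainGDSolver` name and its momentum-free, CPU-only update are invented purely for illustration.

#include "caffe/sgd_solvers.hpp"
#include "caffe/solver_factory.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

// Hypothetical downstream solver built on the interface above.
template <typename Dtype>
class PlainGDSolver : public SGDSolver<Dtype> {
 public:
  explicit PlainGDSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) {}
  explicit PlainGDSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) {}
  // The string returned here is what `type: "PlainGD"` in a solver prototxt
  // would select once the class is registered with the factory.
  virtual inline const char* type() const { return "PlainGD"; }

 protected:
  virtual void ComputeUpdateValue(int param_id, Dtype rate) {
    // Momentum-free variant, CPU path only to keep the sketch short:
    // scale the gradient by the learning rate in place, so the subsequent
    // net update subtracts rate * grad from the weights.
    Blob<Dtype>* param = this->net_->learnable_params()[param_id];
    caffe_scal(param->count(), rate, param->mutable_cpu_diff());
  }

  DISABLE_COPY_AND_ASSIGN(PlainGDSolver);
};

INSTANTIATE_CLASS(PlainGDSolver);
REGISTER_SOLVER_CLASS(PlainGD);  // assumed macro from solver_factory.hpp

}  // namespace caffe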
