nnet3: added an option to training to apply the update per minibatch (more accurate gradients).

danpovey committed Aug 28, 2015
1 parent b1472d0 commit d689217
Showing 3 changed files with 16 additions and 3 deletions.
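
For orientation, here is a minimal usage sketch (not part of the commit) showing how the new option would be enabled from C++. The NnetTrainer constructor signature, include paths, and the helper function name are assumptions; only NnetTrainerOptions, update_per_minibatch and Train(const NnetExample&) come from the diffs below.

// Minimal sketch, not from the commit: assumes the usual nnet3
// NnetTrainer(const NnetTrainerOptions&, Nnet*) constructor.
#include <vector>
#include "nnet3/nnet-example.h"
#include "nnet3/nnet-training.h"

void TrainWithPerMinibatchUpdate(kaldi::nnet3::Nnet *nnet,
                                 const std::vector<kaldi::nnet3::NnetExample> &egs) {
  kaldi::nnet3::NnetTrainerOptions opts;
  // Defer the parameter update until the whole minibatch has been
  // backpropagated, so every gradient is taken w.r.t. the same parameter
  // values (the "more accurate gradients" of the commit message).
  opts.update_per_minibatch = true;
  kaldi::nnet3::NnetTrainer trainer(opts, nnet);
  for (size_t i = 0; i < egs.size(); i++)
    trainer.Train(egs[i]);  // each NnetExample here is one (merged) minibatch
}

The trade-off, visible in the nnet-training.cc change below, is that the model is copied on every minibatch when the option is set.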
2 changes: 1 addition & 1 deletion src/nnet3/nnet-optimize.h
@@ -49,7 +49,7 @@ struct NnetOptimizeOptions {
                    "optimizations");
     opts->Register("propagate-in-place", &propagate_in_place, "Set to false to "
                    "disable optimization that allows in-place propagation");
-    opts->Register("propagate-in-place", &propagate_in_place, "Set to false to "
+    opts->Register("backprop-in-place", &backprop_in_place, "Set to false to "
                    "disable optimization that allows in-place backprop");
     opts->Register("remove-assignments", &remove_assignments, "Set to false to "
                    "disable optimization that removes redundant assignments");
9 changes: 8 additions & 1 deletion src/nnet3/nnet-training.cc
@@ -41,14 +41,21 @@ void NnetTrainer::Train(const NnetExample &eg) {
                                config_.store_component_stats,
                                &request);
   const NnetComputation *computation = compiler_.Compile(request);
+
+  const Nnet *const_nnet = (config_.update_per_minibatch ?
+                            static_cast<const Nnet*>(nnet_->Copy()) :
+                            nnet_);
   NnetComputer computer(config_.compute_config, *computation,
-                        *nnet_, nnet_);
+                        *const_nnet, nnet_);
   // give the inputs to the computer object.
   computer.AcceptInputs(*nnet_, eg);
   computer.Forward();

   this->ProcessOutputs(eg, &computer);
   computer.Backward();
+
+  if (config_.update_per_minibatch)
+    delete const_nnet;
 }

 void NnetTrainer::ProcessOutputs(const NnetExample &eg,
8 changes: 7 additions & 1 deletion src/nnet3/nnet-training.h
@@ -34,13 +34,15 @@ struct NnetTrainerOptions {
   bool store_component_stats;
   int32 print_interval;
   bool debug_computation;
+  bool update_per_minibatch;
   NnetOptimizeOptions optimize_config;
   NnetComputeOptions compute_config;
   NnetTrainerOptions():
       zero_component_stats(true),
       store_component_stats(false),
       print_interval(100),
-      debug_computation(false) { }
+      debug_computation(false),
+      update_per_minibatch(false) { }
   void Register(OptionsItf *opts) {
     opts->Register("store-component-stats", &store_component_stats,
                    "If true, store activations and derivatives for nonlinear "
@@ -51,6 +53,10 @@ struct NnetTrainerOptions {
     opts->Register("print-interval", &print_interval, "Interval (measured in "
                    "minibatches) after which we print out objective function "
                    "during training\n");
+    opts->Register("update-per-minibatch", &update_per_minibatch, "If true, "
+                   "wait to apply model changes until the whole minibatch has "
+                   "been processed (requires copying the model on each "
+                   "minibatch).");

     // register the optimization options with the prefix "optimization".
     ParseOptions optimization_opts("optimization", opts);
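NnetTrainerOptions::Register registers the flag as "update-per-minibatch", so a binary that registers these options directly on its ParseOptions would accept --update-per-minibatch=true. The sketch below is not part of the commit; the main() is hypothetical and only illustrates how such a binary would pick the flag up.

// Hypothetical binary, sketch only: everything except the
// --update-per-minibatch option itself is illustrative.
#include "nnet3/nnet-training.h"
#include "util/parse-options.h"

int main(int argc, char *argv[]) {
  using namespace kaldi;
  using namespace kaldi::nnet3;
  const char *usage = "Sketch illustrating the --update-per-minibatch flag.";
  ParseOptions po(usage);
  NnetTrainerOptions train_config;
  train_config.Register(&po);  // registers --update-per-minibatch, among others
  po.Read(argc, argv);
  // train_config.update_per_minibatch now reflects the command line.
  return 0;
}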
