-
Notifications
You must be signed in to change notification settings - Fork 1
/
evaluate.cpp
71 lines (63 loc) · 3.05 KB
/
evaluate.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// Echo state network evaluation. //
#include "analysis.hpp"
#include "argument_utils.hpp"
#include "benchmarks.hpp"
#include "lcnn.hpp"
#include "simple_esn.hpp"

#include <algorithm>
#include <cstddef>
#include <execution>
#include <iostream>
#include <memory>
#include <mutex>
#include <numeric>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

namespace po = boost::program_options;
/// Evaluate the net on the given benchmark.
///
/// \param net_factory The network to be tested.
/// \param n_evals The number of complete reevaluations of the provided net.
template <typename NetFactory>
std::vector<double>
evaluate(NetFactory net_factory, std::unique_ptr<esn::benchmark_set_base> bench, long n_evals)
{
int af_device = af::getDevice();
// Evaluate the individual repeats in parallel.
std::vector<double> results(n_evals);
std::for_each(std::execution::par, results.begin(), results.end(), [&](double& r) {
// We need to make sure the device is set properly, otherwise
// it sometimes fails on XID errors.
af::setDevice(af_device);
auto net = net_factory(bench->n_ins(), bench->n_outs());
r = bench->evaluate(*net, esn::global_prng);
});
return results;
}
int main(int argc, char* argv[])
{
po::options_description arg_desc{"Generic options"};
arg_desc.add_options() //
("help", //
"Produce help message.") //
("gen.net-type", po::value<std::string>()->default_value("lcnn"), //
"Network type, one of {simple-esn, lcnn}.") //
("gen.benchmark-set", po::value<std::string>()->default_value("narma10"), //
"Benchmark set to be evaluated.") //
("gen.n-evals", po::value<long>()->default_value(3), //
"The number of complete reevaluations of the provided set of parameters.") //
("gen.af-device", po::value<int>()->default_value(0), //
"ArrayFire device to be used."); //
arg_desc.add(esn::benchmark_arg_description());
po::variables_map args = esn::parse_conditional(
argc, argv, arg_desc,
{{"gen.net-type", //
{{"lcnn", esn::lcnn_arg_description()}, //
{"simple-esn", esn::esn_arg_description()}}}}); //
af::setDevice(args.at("gen.af-device").as<int>());
af::info();
std::cout << std::endl;
std::unique_ptr<esn::benchmark_set_base> bench = esn::make_benchmark(args);
auto net_factory = [&](auto... fwd) { return esn::make_net(fwd..., args, esn::global_prng); };
std::string name = args.at("gen.benchmark-set").as<std::string>() + " "
+ args.at("bench.error-measure").as<std::string>();
af::timer::start();
esn::benchmark_results results;
results.insert(
name, evaluate(net_factory, std::move(bench), args.at("gen.n-evals").as<long>()));
std::cout << results << std::endl;
std::cout << "elapsed time: " << af::timer::stop() << std::endl;
return 0;
}