diff --git a/src/net/NetGraph.cpp b/src/net/NetGraph.cpp index 4360d72..4372494 100644 --- a/src/net/NetGraph.cpp +++ b/src/net/NetGraph.cpp @@ -300,6 +300,10 @@ void NetGraph::FeedForward(NetGraphNode* node) { for (NetGraphConnection connection : node->input_connections) FeedForward(connection.node); +#ifdef LAYERTIME + auto t_begin = std::chrono::system_clock::now(); +#endif + PrepareNode(node); // Call the Layer::FeedForward method and set the visited flag node->layer->FeedForward(); @@ -318,6 +322,12 @@ void NetGraph::FeedForward(NetGraphNode* node) { } node->flag_ff_visited = true; + +#ifdef LAYERTIME + auto t_end = std::chrono::system_clock::now(); + std::chrono::duration pass_duration = t_end - t_begin; + LOGINFO << "FeedFwd Layer " << node->unique_name << " (" << node->layer->GetLayerDescription() << ") time:\t" << pass_duration.count() << "s"; +#endif } } @@ -344,11 +354,21 @@ void NetGraph::BackPropagate(NetGraphNode* node) { for (NetGraphConnection connection : node->input_connections) do_backprop |= connection.backprop; +#ifdef LAYERTIME + auto t_begin = std::chrono::system_clock::now(); +#endif + PrepareNode(node); node->layer->SetBackpropagationEnabled(do_backprop); // Call the Layer::FeedForward method and set the visited flag node->layer->BackPropagate(); node->flag_bp_visited = true; + +#ifdef LAYERTIME + auto t_end = std::chrono::system_clock::now(); + std::chrono::duration pass_duration = t_end - t_begin; + LOGINFO << "BackProp Layer " << node->unique_name << " (" << node->layer->GetLayerDescription() << ") time:\t" << pass_duration.count() << "s"; +#endif } } diff --git a/src/util/Init.cpp b/src/util/Init.cpp index 171500a..408fd79 100644 --- a/src/util/Init.cpp +++ b/src/util/Init.cpp @@ -87,7 +87,7 @@ void System::Init(int requested_log_level) { } else log_level = requested_log_level; - LOGINFO << "CN24 v2.0.1 at " STRING_SHA1; + LOGINFO << "CN24 v2.0.2 at " STRING_SHA1; LOGINFO << "Copyright (C) 2015 Clemens-Alexander Brust"; LOGINFO << "For licensing information, see the LICENSE" << " file included with this project."; diff --git a/tools/classifyImage.cpp b/tools/classifyImage.cpp index 37a9e33..cc05e4a 100644 --- a/tools/classifyImage.cpp +++ b/tools/classifyImage.cpp @@ -62,6 +62,8 @@ int main (int argc, char* argv[]) { // Rescale image unsigned int width = original_data_tensor.width(); unsigned int height = original_data_tensor.height(); + unsigned int original_width = original_data_tensor.width(); + unsigned int original_height = original_data_tensor.height(); if(width & 1) width++; if(height & 1) @@ -78,12 +80,37 @@ int main (int argc, char* argv[]) { height+=4; Conv::Tensor data_tensor(1, width, height, original_data_tensor.maps()); + Conv::Tensor helper_tensor(1, width, height, 2); data_tensor.Clear(); + helper_tensor.Clear(); + + // Copy sample because data_tensor may be slightly larger Conv::Tensor::CopySample(original_data_tensor, 0, data_tensor, 0); + // Initialize helper (spatial prior) tensor + + // Write spatial prior data to helper tensor + for (unsigned int y = 0; y < original_height; y++) { + for (unsigned int x = 0; x < original_width; x++) { + *helper_tensor.data_ptr(x, y, 0, 0) = ((Conv::datum)x) / ((Conv::datum)original_width - 1); + *helper_tensor.data_ptr(x, y, 1, 0) = ((Conv::datum)y) / ((Conv::datum)original_height - 1); + } + for (unsigned int x = original_width; x < width; x++) { + *helper_tensor.data_ptr(x, y, 0, 0) = 0; + *helper_tensor.data_ptr(x, y, 1, 0) = 0; + } + } + for (unsigned int y = original_height; y < height; y++) { + for (unsigned int x = 0; x < height; x++) { + *helper_tensor.data_ptr(x, y, 0, 0) = 0; + *helper_tensor.data_ptr(x, y, 1, 0) = 0; + } + } + + // Assemble net Conv::NetGraph graph; - Conv::InputLayer input_layer(data_tensor); + Conv::InputLayer input_layer(data_tensor, helper_tensor); Conv::NetGraphNode input_node(&input_layer); input_node.is_input = true; diff --git a/tools/trainNetwork.cpp b/tools/trainNetwork.cpp index a682fc8..0f55e60 100644 --- a/tools/trainNetwork.cpp +++ b/tools/trainNetwork.cpp @@ -35,8 +35,8 @@ int main (int argc, char* argv[]) { const Conv::datum it_factor = 0.01; #else const Conv::datum it_factor = 1; - const Conv::datum loss_sampling_p = 0.5; #endif + const Conv::datum loss_sampling_p = 0.5; if(argc > 1) { if(std::string(argv[1]).compare("-v") == 0) {