diff --git a/README.md b/README.md index 7daae99..8b50097 100644 --- a/README.md +++ b/README.md @@ -4,16 +4,16 @@ **OpenEM** is a library that provides advanced video analytics for fisheries electronic monitoring (EM) data. It currently supports detection, -classification and measurement of fish during landing or discard. This -functionality is currently only available via a deployment library with +classification, counting and measurement of fish during landing or discard. +This functionality is currently only available via a deployment library with pretrained models available in our example data (see tutorial). The base library is written in C++, with bindings available for both Python and C#. Examples are included for all three languages. There are immediate plans to develop a training library so that users -can build their own models on their own data, and to add counting -functionality. Currently builds have only been tested on Windows. We plan -to support both Ubuntu and macOS in the future. +can build their own models on their own data. Currently builds have only +been tested on Windows. We plan to support both Ubuntu and macOS in the +future. 
/// Contains classification results for a single image.
struct Classification {
  /// Species scores.
  ///
  /// Elements correspond to the species used to train the loaded model.
  std::vector<float> species;

  /// Cover scores.
  ///
  /// Elements correspond to the following:
  /// * No fish in the image.
  /// * Fish is covered by a hand.
  /// * Fish is not covered.
  std::array<float, 3> cover;
};
+// +// You should have received a copy of the GNU General Public License +// along with OpenEM. If not, see . + +#ifndef OPENEM_DEPLOY_COUNT_H_ +#define OPENEM_DEPLOY_COUNT_H_ + +#include +#include +#include + +#include "detect.h" +#include "classify.h" +#include "error_codes.h" + +namespace openem { +namespace count { + +/// Class for finding keyframes. +class KeyframeFinder { + public: + /// Constructor. + KeyframeFinder(); + + /// Destructor. + ~KeyframeFinder(); + + /// Initializes the keyframe finder. + /// @param model_path Path to protobuf file containing model. + /// @param img_width Width of image input to detector. + /// @param img_height Height of image input to detector. + /// @param gpu_fraction Fraction of GPU memory that may be allocated to + /// this object. + /// @return Error code. + ErrorCode Init( + const std::string& model_path, + int img_width, + int img_height, + double gpu_fraction=1.0); + + /// Finds keyframes in a given sequence. + /// @param classifications Sequence of outputs from classifier. Outer + /// vector corresponds to frames, inner corresponds to detections in each + /// frame. + /// @param detections Sequence of outputs from detector. Outer vector + /// corresponds to frames, inner corresponds to detections in each frame. + /// @param keyframes Vector of keyframes in the sequence. The length of + /// this vector is the number of fish in the sequence. The values + /// are the indices of the classification and detection vectors. + /// @return Error code. + ErrorCode Process( + const std::vector>& classifications, + const std::vector>& detections, + std::vector* keyframes); + private: + /// Forward declaration of implementation class. + class KeyframeFinderImpl; + + /// Pointer to implementation. 
+ std::unique_ptr impl_; +}; + +} // namespace count +} // namespace openem + +#endif // OPENEM_DEPLOY_COUNT_H_ + diff --git a/deploy/include/detail/model.h b/deploy/include/detail/model.h index 95265c0..060abc3 100644 --- a/deploy/include/detail/model.h +++ b/deploy/include/detail/model.h @@ -36,6 +36,53 @@ class Model { /// Constructor. Model(); + /// Destructor. + ~Model(); + + /// Loads a model from a protobuf file and initializes the tensorflow + /// session. + /// @param model_path Path to protobuf file containing the model. + /// @param gpu_fraction Fraction fo GPU allowed to be used by this object. + /// @return Error code. + ErrorCode Init(const std::string& model_path, double gpu_fraction); + + /// Returns input size. + /// @return Input size. + std::vector InputSize(); + + /// Returns whether the model has been initialized. + /// @retur True if initialized. + bool Initialized(); + + /// Processes the model on the current batch. + /// @param input Input tensor. + /// @param input_name Name of input tensor. + /// @param output_names Name of output tensors. + /// @param outputs Output of the model. + /// @return Error code. + ErrorCode Process( + const tensorflow::Tensor& input, + const std::string& input_name, + const std::vector& output_names, + std::vector* outputs); + private: + /// Forward declaration of class implementation. + class ModelImpl; + + /// Pointer to implementation. + std::unique_ptr impl_; +}; + +/// Model that accepts image data as input. This class is intended to be +/// an implementation detail only. +class ImageModel { + public: + /// Constructor. + ImageModel(); + + /// Destructor. + ~ImageModel(); + /// Loads a model from a protobuf file and initializes the tensorflow /// session. /// @param model_path Path to protobuf file containing the model. @@ -58,32 +105,20 @@ class Model { std::function preprocess); /// Processes the model on the current batch. - /// @param outputs Output of the model. 
/// @param input_name Name of input tensor. /// @param output_names Name of output tensors. + /// @param outputs Output of the model. /// @return Error code. ErrorCode Process( - std::vector* outputs, const std::string& input_name, - const std::vector& output_names); + const std::vector& output_names, + std::vector* outputs); private: - /// Tensorflow session. - std::unique_ptr session_; - - /// Input image width. - int width_; - - /// Input image height. - int height_; - - /// Indicates whether the model has been initialized. - bool initialized_; - - /// Queue of futures containing preprocessed images. - std::queue> preprocessed_; + /// Forward declaration of class implementation. + class ImageModelImpl; - /// Mutex for handling concurrent access to image queue. - std::mutex mutex_; + /// Pointer to implementation. + std::unique_ptr impl_; }; } // namespace detail diff --git a/deploy/include/detail/util.h b/deploy/include/detail/util.h index 3d8f3f5..0b6e0be 100644 --- a/deploy/include/detail/util.h +++ b/deploy/include/detail/util.h @@ -48,11 +48,19 @@ ErrorCode GetSession(tensorflow::Session** session, double gpu_fraction); /// Gets graph input size. /// @param graph_def Graph definition. +/// @param input_size Input size vector. +/// @return Error code. +ErrorCode InputSize( + const tensorflow::GraphDef& graph_def, + std::vector* input_size); + +/// Gets image size from graph. +/// @param input_size Input dimensions output from InputSize. /// @param width Width dimension of input layer. /// @param height Height dimension of input layer. /// @return Error code. -ErrorCode InputSize( - const tensorflow::GraphDef& graph_def, +ErrorCode ImageSize( + const std::vector& input_size, int* width, int* height); diff --git a/deploy/include/detect.h b/deploy/include/detect.h index a81a7f0..fb8debc 100644 --- a/deploy/include/detect.h +++ b/deploy/include/detect.h @@ -28,6 +28,18 @@ namespace openem { namespace detect { +/// Contains information for a single detection. 
+struct Detection { + /// Location of the detection in the frame. + Rect location; + + /// Confidence score. + float confidence; + + /// Species index based on highest confidence. + int species; +}; + /// Class for detecting fish in images. class Detector { public: @@ -64,7 +76,9 @@ class Detector { /// Finds fish in batched images by performing object detection /// with Single Shot Detector (SSD). - ErrorCode Process(std::vector>* detections); + /// @param detections Vector of detections for each image. + /// @return Error code. + ErrorCode Process(std::vector>* detections); private: /// Forward declaration of implementation class. class DetectorImpl; diff --git a/deploy/include/error_codes.h b/deploy/include/error_codes.h index 7f64d64..7b6f701 100644 --- a/deploy/include/error_codes.h +++ b/deploy/include/error_codes.h @@ -36,7 +36,10 @@ enum ErrorCode { kErrorNumChann, ///< Invalid number of image channels. kErrorVidFileOpen, ///< Failed to open video file. kErrorVidNotOpen, ///< Tried to read from unopened video. - kErrorVidFrame ///< Failed to read video frame. + kErrorVidFrame, ///< Failed to read video frame. + kErrorLenMismatch, ///< Mismatch in sequence lengths. + kErrorNumInputDims, ///< Unexpected number of input dimensions. + kErrorBadSeqLength ///< Wrong sequence length for input. }; } // namespace openem diff --git a/deploy/src/CMakeLists.txt b/deploy/src/CMakeLists.txt index 315543e..b1b4ed5 100644 --- a/deploy/src/CMakeLists.txt +++ b/deploy/src/CMakeLists.txt @@ -3,6 +3,7 @@ add_library(openem find_ruler.cc detect.cc classify.cc + count.cc util.cc model.cc video.cc diff --git a/deploy/src/classify.cc b/deploy/src/classify.cc index 0e0d869..b8c5ba1 100644 --- a/deploy/src/classify.cc +++ b/deploy/src/classify.cc @@ -27,7 +27,7 @@ namespace classify { class Classifier::ClassifierImpl { public: /// Stores and processes the model. 
- detail::Model model_; + detail::ImageModel model_; }; Classifier::Classifier() : impl_(new ClassifierImpl()) {} @@ -58,29 +58,30 @@ ErrorCode Classifier::AddImage(const Image& image) { return impl_->model_.AddImage(*mat, preprocess); } -ErrorCode Classifier::Process(std::vector>* scores) { +ErrorCode Classifier::Process(std::vector* results) { // Run the model. std::vector outputs; ErrorCode status = impl_->model_.Process( - &outputs, "data", - {"cat_species_1:0", "cat_cover_1:0"}); + {"cat_species_1:0", "cat_cover_1:0"}, + &outputs); if (status != kSuccess) return status; // Convert to mat vector. std::vector species; detail::TensorToMatVec(outputs[0], &species, 1.0, 0.0, CV_32F); - std::vector quality; - detail::TensorToMatVec(outputs[1], &quality, 1.0, 0.0, CV_32F); + std::vector cover; + detail::TensorToMatVec(outputs[1], &cover, 1.0, 0.0, CV_32F); // Clear input results. - scores->clear(); + results->clear(); // Iterate through results for each image. - for(int i = 0; i < species.size(); ++i) { - std::vector vec(quality[i].begin(), quality[i].end()); - vec.insert(vec.end(), species[i].begin(), species[i].end()); - scores->push_back(std::move(vec)); + for (int i = 0; i < species.size(); ++i) { + Classification c; + c.species.assign(species[i].begin(), species[i].end()); + for (int j = 0; j < 3; ++j) c.cover[j] = cover[i].at(j); + results->push_back(std::move(c)); } return kSuccess; } diff --git a/deploy/src/count.cc b/deploy/src/count.cc new file mode 100644 index 0000000..00b4277 --- /dev/null +++ b/deploy/src/count.cc @@ -0,0 +1,138 @@ +/// @file +/// @brief Implementation for counting fish. +/// @copyright Copyright (C) 2018 CVision AI. +/// @license This file is part of OpenEM, released under GPLv3. +// OpenEM is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// any later version. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with OpenEM. If not, see . + +#include "count.h" + +#include +#include "detail/model.h" +#include "detail/util.h" + +namespace openem { +namespace count { + +namespace tf = tensorflow; + +/// Implementation details for KeyframeFinder. +class KeyframeFinder::KeyframeFinderImpl { + public: + /// Stores and processes the model. + detail::Model model_; + + /// Stores image width. Needed for feature normalization. + float width_; + + /// Stores image height. Needed for feature normalization. + float height_; + + /// Model input size. + std::vector input_size_; +}; + +KeyframeFinder::KeyframeFinder() : impl_(new KeyframeFinderImpl()) {} + +KeyframeFinder::~KeyframeFinder() {} + +ErrorCode KeyframeFinder::Init( + const std::string& model_path, + int img_width, + int img_height, + double gpu_fraction) { + impl_->width_ = static_cast(img_width); + impl_->height_ = static_cast(img_height); + ErrorCode status = impl_->model_.Init(model_path, gpu_fraction); + if (status != kSuccess) return status; + impl_->input_size_ = impl_->model_.InputSize(); + if (impl_->input_size_.size() != 3) return kErrorNumInputDims; + return kSuccess; +} + +ErrorCode KeyframeFinder::Process( + const std::vector>& classifications, + const std::vector>& detections, + std::vector* keyframes) { + constexpr float kKeyframeThresh = 0.2; + constexpr int kKeyframeOffset = 32; + + // Get tensor size and do size checks. 
+ if (classifications.size() != detections.size()) return kErrorLenMismatch; + int seq_len = impl_->input_size_[1]; + int fea_len = impl_->input_size_[2]; + tf::TensorShape shape({1, seq_len, fea_len}); + tf::Tensor seq_tensor = tf::Input::Initializer(0.0f, shape).tensor; + auto seq = seq_tensor.flat(); + + // Copy data from each detection/classification into sequence tensor. + for (int n = 0, offset = 0; n < detections.size(); ++n) { + + // If this frame has no detections, continue. + if (detections[n].empty()) { + offset += fea_len; + continue; + } + + // More size checking. + const auto& c = classifications[n][0]; + const auto& d = detections[n][0]; + int num_species = static_cast(c.species.size()); + int num_cover = static_cast(c.cover.size()); + int num_loc = static_cast(d.location.size()); + int num_fea = num_species + num_cover + num_loc + 2; + if (fea_len != (2 * num_fea)) return kErrorBadSeqLength; + + // Normalize the bounding boxes to image size. + std::array norm_loc = { + static_cast(d.location[0]) / impl_->width_, + static_cast(d.location[1]) / impl_->height_, + static_cast(d.location[2]) / impl_->width_, + static_cast(d.location[3]) / impl_->height_}; + + // Copy twice for now, this is how the existing model works. + for (int m = 0; m < 2; m++) { + std::copy_n(c.species.data(), num_species, seq.data() + offset); + offset += num_species; + std::copy_n(c.cover.data(), num_cover, seq.data() + offset); + offset += num_cover; + std::copy_n(norm_loc.data(), num_loc, seq.data() + offset); + offset += num_loc; + seq(offset) = d.confidence; + offset++; + seq(offset) = static_cast(d.species); + offset++; + } + } + // Process the model. + std::vector outputs; + ErrorCode status = impl_->model_.Process( + seq_tensor, + "input_1", + {"cumsum_values_1:0"}, + &outputs); + + // Find values over a threshold. 
+ keyframes->clear(); + auto out = outputs[0].tensor(); + for (int i = 0; i < detections.size(); ++i) { + if (out(0, i) > kKeyframeThresh) { + keyframes->push_back(i + kKeyframeOffset); + } + } + return kSuccess; +} + +} // namespace count +} // namespace openem + diff --git a/deploy/src/detect.cc b/deploy/src/detect.cc index d9ba55e..8ccfa24 100644 --- a/deploy/src/detect.cc +++ b/deploy/src/detect.cc @@ -46,7 +46,7 @@ std::vector DecodeBoxes( class Detector::DetectorImpl { public: /// Stores and processes the model. - detail::Model model_; + detail::ImageModel model_; /// Stores image scale factors. std::vector> img_scale_; @@ -130,16 +130,15 @@ ErrorCode Detector::AddImage(const Image& image) { return impl_->model_.AddImage(*mat, preprocess); } -ErrorCode Detector::Process( - std::vector>>* detections) { +ErrorCode Detector::Process(std::vector>* detections) { constexpr int kBackgroundClass = 0; // Run the model. std::vector outputs; ErrorCode status = impl_->model_.Process( - &outputs, "input_1", - {"output_node0:0"}); + {"output_node0:0"}, + &outputs); if (status != kSuccess) return status; // Convert to mat vector. 
@@ -164,18 +163,22 @@ ErrorCode Detector::Process( anchors = p(cv::Range::all(), cv::Range(conf_stop, anc_stop)); variances = p(cv::Range::all(), cv::Range(anc_stop, var_stop)); boxes = DecodeBoxes(loc, anchors, variances, impl_->model_.ImageSize()); - std::vector> dets; + std::vector dets; for (int c = 0; c < conf.cols; ++c) { if (c == kBackgroundClass) continue; cv::Mat& c_conf = conf.col(c); scores.assign(c_conf.begin(), c_conf.end()); cv::dnn::NMSBoxes(boxes, scores, 0.01, 0.45, indices, 1.0, 200); for (int idx : indices) { - dets.emplace_back(std::array{ - int(boxes[idx].x * impl_->img_scale_[i].first), - int(boxes[idx].y * impl_->img_scale_[i].second), - int(boxes[idx].width * impl_->img_scale_[i].first), - int(boxes[idx].height * impl_->img_scale_[i].second)}); + Detection det; + det.location = { + int(boxes[idx].x * impl_->img_scale_[i].first), + int(boxes[idx].y * impl_->img_scale_[i].second), + int(boxes[idx].width * impl_->img_scale_[i].first), + int(boxes[idx].height * impl_->img_scale_[i].second)}; + det.confidence = scores[idx]; + det.species = c; + dets.push_back(std::move(det)); } } impl_->img_scale_.clear(); diff --git a/deploy/src/find_ruler.cc b/deploy/src/find_ruler.cc index 1eb9138..e76b55d 100644 --- a/deploy/src/find_ruler.cc +++ b/deploy/src/find_ruler.cc @@ -28,7 +28,7 @@ namespace find_ruler { class RulerMaskFinder::RulerMaskFinderImpl { public: /// Stores and processes the model. - detail::Model model_; + detail::ImageModel model_; }; RulerMaskFinder::RulerMaskFinder() : impl_(new RulerMaskFinderImpl()) {} @@ -62,9 +62,9 @@ ErrorCode RulerMaskFinder::Process(std::vector* masks) { // Run the model. std::vector outputs; ErrorCode status = impl_->model_.Process( - &outputs, "input_1", - {"output_node0:0"}); + {"output_node0:0"}, + &outputs); if (status != kSuccess) return status; // Copy model outputs into mask images. 
diff --git a/deploy/src/model.cc b/deploy/src/model.cc index 9c4608a..376d4bd 100644 --- a/deploy/src/model.cc +++ b/deploy/src/model.cc @@ -26,81 +26,140 @@ namespace detail { namespace tf = tensorflow; -Model::Model() - : session_(nullptr), - width_(0), - height_(0), - initialized_(false), - preprocessed_(), - mutex_() { -} +/// Implementation details for Model. +class Model::ModelImpl { + public: + /// Tensorflow session. + std::unique_ptr session_; + + /// Input size. + std::vector input_size_; + + /// Indicates whether the model has been initialized. + bool initialized_; +}; + +/// Implementation details for ImageModel. +class ImageModel::ImageModelImpl { + public: + /// Model object. + Model model_; + + /// Image width. + int width_; + + /// Image height. + int height_; + + /// Queue of futures containing preprocessed images. + std::queue> preprocessed_; + + /// Mutex for handling concurrent access to image queue. + std::mutex mutex_; +}; + +Model::Model() : impl_(new ModelImpl()) {} + +Model::~Model() {} ErrorCode Model::Init( const std::string& model_path, double gpu_fraction) { - initialized_ = false; + impl_->initialized_ = false; // Read in the graph. tf::GraphDef graph_def; tf::Status status = tf::ReadBinaryProto( tf::Env::Default(), model_path, &graph_def); if (!status.ok()) return kErrorLoadingModel; - + // Get graph input size. - ErrorCode status1 = InputSize( - graph_def, &(width_), &(height_)); + ErrorCode status1 = detail::InputSize(graph_def, &(impl_->input_size_)); if (status1 != kSuccess) return status1; - + // Create a new tensorflow session. tf::Session* session; status1 = GetSession(&session, gpu_fraction); if (status1 != kSuccess) return status1; - session_.reset(session); + impl_->session_.reset(session); // Create the tensorflow graph. 
- status = session_->Create(graph_def); + status = impl_->session_->Create(graph_def); if (!status.ok()) return kErrorTfGraph; - initialized_ = true; + impl_->initialized_ = true; return kSuccess; } -cv::Size Model::ImageSize() { - return cv::Size(width_, height_); +std::vector Model::InputSize() { + return impl_->input_size_; } -ErrorCode Model::AddImage( +bool Model::Initialized() { + return impl_->initialized_; +} + +ErrorCode Model::Process( + const tf::Tensor& input, + const std::string& input_name, + const std::vector& output_names, + std::vector* outputs) { + tf::Status status = impl_->session_->Run( + {{input_name, input}}, + output_names, + {}, + outputs); + if (!status.ok()) return kErrorRunSession; + return kSuccess; +} + +ImageModel::ImageModel() : impl_(new ImageModelImpl()) {} + +ImageModel::~ImageModel() {} + +ErrorCode ImageModel::Init( + const std::string& model_path, double gpu_fraction) { + // Do model initialization. + ErrorCode status = impl_->model_.Init(model_path, gpu_fraction); + if (status != kSuccess) return status; + + // Get graph input size. 
+ ErrorCode status1 = detail::ImageSize( + impl_->model_.InputSize(), &(impl_->width_), &(impl_->height_)); + if (status != kSuccess) return status; + return kSuccess; +} + +cv::Size ImageModel::ImageSize() { + return cv::Size(impl_->width_, impl_->height_); +} + +ErrorCode ImageModel::AddImage( const cv::Mat& image, std::function preprocess) { - if (!initialized_) return kErrorBadInit; + if (!impl_->model_.Initialized()) return kErrorBadInit; if (!image.isContinuous()) return kErrorNotContinuous; - auto f = std::async(preprocess, image, width_, height_); - mutex_.lock(); - preprocessed_.push(std::move(f)); - mutex_.unlock(); + auto f = std::async(preprocess, image, impl_->width_, impl_->height_); + impl_->mutex_.lock(); + impl_->preprocessed_.push(std::move(f)); + impl_->mutex_.unlock(); return kSuccess; } -ErrorCode Model::Process( - std::vector* outputs, +ErrorCode ImageModel::Process( const std::string& input_name, - const std::vector& output_names) { - if (!initialized_) return kErrorBadInit; + const std::vector& output_names, + std::vector* outputs) { + if (!(impl_->model_.Initialized())) return kErrorBadInit; // Copy image queue contents into input tensor. - mutex_.lock(); + impl_->mutex_.lock(); tf::Tensor input = FutureQueueToTensor( - &preprocessed_, - width_, - height_); - mutex_.unlock(); + &impl_->preprocessed_, + impl_->width_, + impl_->height_); + impl_->mutex_.unlock(); // Run the model. 
- tf::Status status = session_->Run( - {{input_name, input}}, - output_names, - {}, - outputs); - if (!status.ok()) return kErrorRunSession; - return kSuccess; + return impl_->model_.Process(input, input_name, output_names, outputs); } } // namespace detail diff --git a/deploy/src/util.cc b/deploy/src/util.cc index 2c633b8..b5893d6 100644 --- a/deploy/src/util.cc +++ b/deploy/src/util.cc @@ -42,21 +42,35 @@ ErrorCode GetSession(tf::Session** session, double gpu_fraction) { return kSuccess; } -ErrorCode InputSize(const tf::GraphDef& graph_def, int* width, int* height) { +ErrorCode InputSize( + const tf::GraphDef& graph_def, + std::vector* input_size) { bool found = false; for (auto p : graph_def.node(0).attr()) { if (p.first == "shape") { found = true; auto shape = p.second.shape(); - if (shape.dim_size() != 4) return kErrorGraphDims; - *width = static_cast(shape.dim(2).size()); - *height = static_cast(shape.dim(1).size()); + if (shape.dim_size() < 1) return kErrorGraphDims; + input_size->resize(shape.dim_size()); + for (int i = 0; i < input_size->size(); ++i) { + (*input_size)[i] = static_cast(shape.dim(i).size()); + } } } if (!found) return kErrorNoShape; return kSuccess; } +ErrorCode ImageSize( + const std::vector& input_size, + int* width, + int* height) { + if (input_size.size() != 4) return kErrorGraphDims; + *width = input_size[2]; + *height = input_size[1]; + return kSuccess; +} + tf::Tensor ImageToTensor(const Image& image, const tf::TensorShape& shape) { tf::Tensor tensor(tf::DT_FLOAT, shape); auto flat = tensor.flat(); diff --git a/examples/deploy/cc/classify.cc b/examples/deploy/cc/classify.cc index 089ff15..6fadd97 100644 --- a/examples/deploy/cc/classify.cc +++ b/examples/deploy/cc/classify.cc @@ -65,30 +65,30 @@ int main(int argc, char* argv[]) { } // Process the loaded images. 
- std::vector> scores; - status = classifier.Process(&scores); + std::vector classifications; + status = classifier.Process(&classifications); if (status != em::kSuccess) { std::cout << "Error when attempting to do classification!" << std::endl; return -1; } // Display the images and print scores to console. - for (int i = 0; i < scores.size(); ++i) { - const std::vector& score = scores[i]; + for (int i = 0; i < classifications.size(); ++i) { + const cl::Classification& classification = classifications[i]; std::cout << "*******************************************" << std::endl; std::cout << "Fish cover scores:" << std::endl; - std::cout << "No fish: " << score[0] << std::endl; - std::cout << "Hand over fish: " << score[1] << std::endl; - std::cout << "Fish clear: " << score[2] << std::endl; + std::cout << "No fish: " << classification.cover[0] << std::endl; + std::cout << "Hand over fish: " << classification.cover[1] << std::endl; + std::cout << "Fish clear: " << classification.cover[2] << std::endl; std::cout << "*******************************************" << std::endl; std::cout << "Fish species scores:" << std::endl; - std::cout << "Fourspot: " << score[3] << std::endl; - std::cout << "Grey sole: " << score[4] << std::endl; - std::cout << "Other: " << score[5] << std::endl; - std::cout << "Plaice: " << score[6] << std::endl; - std::cout << "Summer: " << score[7] << std::endl; - std::cout << "Windowpane: " << score[8] << std::endl; - std::cout << "Winter: " << score[9] << std::endl; + std::cout << "Fourspot: " << classification.species[0] << std::endl; + std::cout << "Grey sole: " << classification.species[1] << std::endl; + std::cout << "Other: " << classification.species[2] << std::endl; + std::cout << "Plaice: " << classification.species[3] << std::endl; + std::cout << "Summer: " << classification.species[4] << std::endl; + std::cout << "Windowpane: " << classification.species[5] << std::endl; + std::cout << "Winter: " << classification.species[6] << 
std::endl; std::cout << std::endl; imgs[i].Show(); } diff --git a/examples/deploy/cc/detect.cc b/examples/deploy/cc/detect.cc index 8c1c193..9a69f9f 100644 --- a/examples/deploy/cc/detect.cc +++ b/examples/deploy/cc/detect.cc @@ -63,7 +63,7 @@ int main(int argc, char* argv[]) { } // Process the loaded images. - std::vector>> detections; + std::vector> detections; status = detector.Process(&detections); if (status != em::kSuccess) { std::cout << "Error when attempting to do detection!" << std::endl; @@ -73,13 +73,13 @@ int main(int argc, char* argv[]) { // Display the detections on the image. for (int i = 0; i < detections.size(); ++i) { em::Image img = imgs[i]; - std::vector> dets = detections[i]; + std::vector dets = detections[i]; if (dets.size() == 0) { std::cout << "No detections found for image " << i << std::endl; continue; } for (auto det : dets) { - img.DrawRect(det); + img.DrawRect(det.location); } img.Show(); } diff --git a/examples/deploy/cc/video.cc b/examples/deploy/cc/video.cc index c4b2201..4cce393 100644 --- a/examples/deploy/cc/video.cc +++ b/examples/deploy/cc/video.cc @@ -17,10 +17,12 @@ #include #include +#include #include "find_ruler.h" #include "detect.h" #include "classify.h" +#include "count.h" #include "video.h" // Declare namespace alias for shorthand. @@ -52,8 +54,21 @@ em::ErrorCode DetectAndClassify( const std::string& vid_path, const em::Rect& roi, const std::vector& transform, - std::vector>* detections, - std::vector>>* scores); + std::vector>* detections, + std::vector>* scores); + +/// Writes a csv file containing fish species and frame numbers. +/// @param count_path Path to count model file. +/// @param out_path Path to output csv file. +/// @param roi Region of interest, needed for image width and height. +/// @param detections Detections for each frame. +/// @param scores Cover and species scores for each detection. 
+em::ErrorCode WriteCounts( + const std::string& count_path, + const std::string& out_path, + const em::Rect& roi, + const std::vector>& detections, + const std::vector>& scores); /// Writes a new video with bounding boxes around detections. /// @param vid_path Path to the original video. @@ -67,21 +82,22 @@ em::ErrorCode WriteVideo( const std::string& out_path, const em::Rect& roi, const std::vector& transform, - const std::vector>& detections, - const std::vector>>& scores); + const std::vector>& detections, + const std::vector>& scores); int main(int argc, char* argv[]) { // Check input arguments. - if (argc < 5) { + if (argc < 6) { std::cout << "Expected at least four arguments: " << std::endl; std::cout << " Path to pb file with find_ruler model." << std::endl; std::cout << " Path to pb file with detect model." << std::endl; std::cout << " Path to pb file with classify model." << std::endl; + std::cout << " Path to pb file with count model." << std::endl; std::cout << " Path to one or more video files." << std::endl; } - for (int vid_idx = 4; vid_idx < argc; ++vid_idx) { + for (int vid_idx = 5; vid_idx < argc; ++vid_idx) { // Find the roi. std::cout << "Finding region of interest..." << std::endl; em::Rect roi; @@ -91,8 +107,8 @@ int main(int argc, char* argv[]) { // Find detections and classify them. std::cout << "Performing detection and classification..." << std::endl; - std::vector> detections; - std::vector>> scores; + std::vector> detections; + std::vector> scores; status = DetectAndClassify( argv[2], argv[3], @@ -103,10 +119,21 @@ int main(int argc, char* argv[]) { &scores); if (status != em::kSuccess) return -1; + // Write fish counts to file. + std::cout << "Writing counts to file..." << std::endl; + std::stringstream ss1; + ss1 << "fish_counts_" << vid_idx - 5 << ".csv"; + status = WriteCounts( + argv[4], + ss1.str(), + roi, + detections, + scores); + // Write annotated video to file. std::cout << "Writing video to file..." 
<< std::endl; std::stringstream ss; - ss << "annotated_video_" << vid_idx - 4 << ".avi"; + ss << "annotated_video_" << vid_idx - 5 << ".avi"; status = WriteVideo( argv[vid_idx], ss.str(), @@ -194,8 +221,8 @@ em::ErrorCode DetectAndClassify( const std::string& vid_path, const em::Rect& roi, const std::vector& transform, - std::vector>* detections, - std::vector>>* scores) { + std::vector>* detections, + std::vector>* scores) { // Determined by experimentation with GPU having 8GB memory. static const int kMaxImg = 32; @@ -228,7 +255,7 @@ em::ErrorCode DetectAndClassify( while (true) { // Find detections. - std::vector> dets; + std::vector> dets; std::vector imgs; for (int i = 0; i < kMaxImg; ++i) { em::Image img; @@ -255,9 +282,11 @@ em::ErrorCode DetectAndClassify( // Classify detections. for (int i = 0; i < dets.size(); ++i) { - std::vector> score_batch; + std::vector score_batch; for (int j = 0; j < dets[i].size(); ++j) { - em::Image det_img = em::detect::GetDetImage(imgs[i], dets[i][j]); + em::Image det_img = em::detect::GetDetImage( + imgs[i], + dets[i][j].location); status = classifier.AddImage(det_img); if (status != em::kSuccess) { std::cout << "Failed to add frame to classifier!" << std::endl; @@ -276,13 +305,57 @@ em::ErrorCode DetectAndClassify( return em::kSuccess; } +em::ErrorCode WriteCounts( + const std::string& count_path, + const std::string& out_path, + const em::Rect& roi, + const std::vector>& detections, + const std::vector>& scores) { + + // Create and initialize keyframe finder. + em::count::KeyframeFinder finder; + em::ErrorCode status = finder.Init(count_path, roi[2], roi[3]); + if (status != em::kSuccess) { + std::cout << "Failed to initialize keyframe finder!" << std::endl; + return status; + } + + // Process the keyframe finder. + std::vector keyframes; + status = finder.Process(scores, detections, &keyframes); + if (status != em::kSuccess) { + std::cout << "Failed to process keyframe finder!" 
<< std::endl; + return status; + } + + // Write the keyframes out. + std::ofstream csv(out_path); + csv << "id,frame,species_index" << std::endl; + int id = 0; + for (auto i : keyframes) { + csv << id << "," << i << ","; + const auto& c = scores[i][0]; + float max_score = 0.0; + int species_index = 0; + for (int j = 0; j < c.species.size(); ++j) { + if (c.species[j] > max_score) { + max_score = c.species[j]; + species_index = j; + } + } + csv << species_index << std::endl; + id++; + } + return em::kSuccess; +} + em::ErrorCode WriteVideo( const std::string& vid_path, const std::string& out_path, const em::Rect& roi, const std::vector& transform, - const std::vector>& detections, - const std::vector>>& scores) { + const std::vector>& detections, + const std::vector>& scores) { // Initialize the video reader. em::VideoReader reader; @@ -315,8 +388,8 @@ em::ErrorCode WriteVideo( frame.DrawRect(roi, {255, 0, 0}, 1, transform); for (int j = 0; j < detections[i].size(); ++j) { em::Color det_color; - double clear = scores[i][j][2]; - double hand = scores[i][j][1]; + double clear = scores[i][j].cover[2]; + double hand = scores[i][j].cover[1]; if (j == 0) { if (clear > hand) { frame.DrawText("Clear", {0, 0}, {0, 255, 0}); @@ -326,7 +399,7 @@ em::ErrorCode WriteVideo( det_color = {0, 0, 255}; } } - frame.DrawRect(detections[i][j], det_color, 2, transform, roi); + frame.DrawRect(detections[i][j].location, det_color, 2, transform, roi); } status = writer.AddFrame(frame); if (status != em::kSuccess) { diff --git a/examples/deploy/csharp/CMakeLists.txt b/examples/deploy/csharp/CMakeLists.txt index 445a0bf..da683b3 100644 --- a/examples/deploy/csharp/CMakeLists.txt +++ b/examples/deploy/csharp/CMakeLists.txt @@ -10,8 +10,13 @@ set(SWIG_SRC ${PROJECT_BINARY_DIR}/deploy/bindings/VectorUint8.cs ${PROJECT_BINARY_DIR}/deploy/bindings/VectorRect.cs ${PROJECT_BINARY_DIR}/deploy/bindings/VectorImage.cs + ${PROJECT_BINARY_DIR}/deploy/bindings/VectorDetection.cs + 
${PROJECT_BINARY_DIR}/deploy/bindings/VectorVectorDetection.cs + ${PROJECT_BINARY_DIR}/deploy/bindings/VectorClassification.cs + ${PROJECT_BINARY_DIR}/deploy/bindings/VectorVectorClassification.cs ${PROJECT_BINARY_DIR}/deploy/bindings/VectorFloat.cs ${PROJECT_BINARY_DIR}/deploy/bindings/VectorDouble.cs + ${PROJECT_BINARY_DIR}/deploy/bindings/VectorInt.cs ${PROJECT_BINARY_DIR}/deploy/bindings/SWIGTYPE_p_void.cs ${PROJECT_BINARY_DIR}/deploy/bindings/SWIGTYPE_p_unsigned_char.cs ${PROJECT_BINARY_DIR}/deploy/bindings/RulerMaskFinder.cs @@ -20,11 +25,15 @@ set(SWIG_SRC ${PROJECT_BINARY_DIR}/deploy/bindings/openemPINVOKE.cs ${PROJECT_BINARY_DIR}/deploy/bindings/openem.cs ${PROJECT_BINARY_DIR}/deploy/bindings/Image.cs + ${PROJECT_BINARY_DIR}/deploy/bindings/Detection.cs + ${PROJECT_BINARY_DIR}/deploy/bindings/Classification.cs ${PROJECT_BINARY_DIR}/deploy/bindings/ErrorCode.cs ${PROJECT_BINARY_DIR}/deploy/bindings/Detector.cs ${PROJECT_BINARY_DIR}/deploy/bindings/Color.cs ${PROJECT_BINARY_DIR}/deploy/bindings/Codec.cs - ${PROJECT_BINARY_DIR}/deploy/bindings/Classifier.cs) + ${PROJECT_BINARY_DIR}/deploy/bindings/ArrayFloat3.cs + ${PROJECT_BINARY_DIR}/deploy/bindings/Classifier.cs + ${PROJECT_BINARY_DIR}/deploy/bindings/KeyframeFinder.cs) add_custom_target(swig_source_gen COMMAND "" DEPENDS openem_cs diff --git a/examples/deploy/csharp/classify.cs b/examples/deploy/csharp/classify.cs index e75117e..4319089 100644 --- a/examples/deploy/csharp/classify.cs +++ b/examples/deploy/csharp/classify.cs @@ -67,7 +67,7 @@ static int Main(string[] args) { } // Process the loaded images. 
- VectorVectorFloat scores = new VectorVectorFloat(); + VectorClassification scores = new VectorClassification(); status = classifier.Process(scores); if (status != ErrorCode.kSuccess) { Console.WriteLine("Failed to process images!"); @@ -78,18 +78,18 @@ static int Main(string[] args) { for (int i = 0; i < scores.Count; ++i) { Console.WriteLine("*******************************************"); Console.WriteLine("Fish cover scores:"); - Console.WriteLine("No fish: {0}", scores[i][0]); - Console.WriteLine("Hand over fish: {0}", scores[i][1]); - Console.WriteLine("Fish clear: {0}", scores[i][2]); + Console.WriteLine("No fish: {0}", scores[i].cover[0]); + Console.WriteLine("Hand over fish: {0}", scores[i].cover[1]); + Console.WriteLine("Fish clear: {0}", scores[i].cover[2]); Console.WriteLine("*******************************************"); Console.WriteLine("Fish species scores:"); - Console.WriteLine("Fourspot: {0}", scores[i][3]); - Console.WriteLine("Grey sole: {0}", scores[i][4]); - Console.WriteLine("Other: {0}", scores[i][5]); - Console.WriteLine("Plaice: {0}", scores[i][6]); - Console.WriteLine("Summer: {0}", scores[i][7]); - Console.WriteLine("Windowpane: {0}", scores[i][8]); - Console.WriteLine("Winter: {0}", scores[i][9]); + Console.WriteLine("Fourspot: {0}", scores[i].species[0]); + Console.WriteLine("Grey sole: {0}", scores[i].species[1]); + Console.WriteLine("Other: {0}", scores[i].species[2]); + Console.WriteLine("Plaice: {0}", scores[i].species[3]); + Console.WriteLine("Summer: {0}", scores[i].species[4]); + Console.WriteLine("Windowpane: {0}", scores[i].species[5]); + Console.WriteLine("Winter: {0}", scores[i].species[6]); Console.WriteLine(""); imgs[i].Show(); } diff --git a/examples/deploy/csharp/detect.cs b/examples/deploy/csharp/detect.cs index d5061d1..0275009 100644 --- a/examples/deploy/csharp/detect.cs +++ b/examples/deploy/csharp/detect.cs @@ -65,7 +65,7 @@ static int Main(string[] args) { } // Process the loaded images. 
- VectorVectorRect detections = new VectorVectorRect(); + VectorVectorDetection detections = new VectorVectorDetection(); status = detector.Process(detections); if (status != ErrorCode.kSuccess) { Console.WriteLine("Failed to process images!"); @@ -75,7 +75,7 @@ static int Main(string[] args) { // Display the detections on the image. for (int i = 0; i < detections.Count; ++i) { foreach (var det in detections[i]) { - imgs[i].DrawRect(det); + imgs[i].DrawRect(det.location); } imgs[i].Show(); } diff --git a/examples/deploy/csharp/video.cs b/examples/deploy/csharp/video.cs index 0730542..4182a4e 100644 --- a/examples/deploy/csharp/video.cs +++ b/examples/deploy/csharp/video.cs @@ -103,14 +103,14 @@ static void DetectAndClassify( string vid_path, Rect roi, VectorDouble transform, - out VectorVectorRect detections, - out VectorVectorVectorFloat scores) { + out VectorVectorDetection detections, + out VectorVectorClassification scores) { // Determined by experimentation with GPU having 8GB memory. const int kMaxImg = 32; // Initialize the outputs. - detections = new VectorVectorRect(); - scores = new VectorVectorVectorFloat(); + detections = new VectorVectorDetection(); + scores = new VectorVectorClassification(); // Create and initialize the detector. Detector detector = new Detector(); @@ -138,7 +138,7 @@ static void DetectAndClassify( while (true) { // Find detections. - VectorVectorRect dets = new VectorVectorRect(); + VectorVectorDetection dets = new VectorVectorDetection(); VectorImage imgs = new VectorImage(); for (int i = 0; i < kMaxImg; ++i) { Image img = new Image(); @@ -167,9 +167,9 @@ static void DetectAndClassify( // Classify detections. 
for (int i = 0; i < dets.Count; ++i) { - VectorVectorFloat score_batch = new VectorVectorFloat(); + VectorClassification score_batch = new VectorClassification(); for (int j = 0; j < dets[i].Count; ++j) { - Image det_img = openem.GetDetImage(imgs[i], dets[i][j]); + Image det_img = openem.GetDetImage(imgs[i], dets[i][j].location); status = classifier.AddImage(det_img); if (status != ErrorCode.kSuccess) { throw new Exception("Failed to add frame to classifier!"); @@ -185,6 +185,56 @@ static void DetectAndClassify( } } + /// + /// Writes a csv file containing fish species and frame numbers. + /// + /// Path to model file. + /// Path to output csv file. + /// Region of interest, needed for image dims. + /// Detections for each frame. + /// Cover and species scores for each detection. + static void WriteCounts( + string count_path, + string out_path, + Rect roi, + VectorVectorDetection detections, + VectorVectorClassification scores) { + + // Create and initialize keyframe finder. + KeyframeFinder finder = new KeyframeFinder(); + ErrorCode status = finder.Init(count_path, roi[2], roi[3]); + if (status != ErrorCode.kSuccess) { + throw new Exception("Failed to initialize keyframe finder!"); + } + + // Process keyframe finder. + VectorInt keyframes = new VectorInt(); + status = finder.Process(scores, detections, keyframes); + if (status != ErrorCode.kSuccess) { + throw new Exception("Failed to process keyframe finder!"); + } + + // Write the keyframes out. 
+ using (var csv = new System.IO.StreamWriter(out_path)) { + csv.WriteLine("id,frame,species_index"); + int id = 0; + foreach (var i in keyframes) { + Classification c = scores[i][0]; + float max_score = 0.0F; + int species_index = 0; + for (int j = 0; j < c.species.Count; ++j) { + if (c.species[j] > max_score) { + max_score = c.species[j]; + species_index = j; + } + } + var line = string.Format("{0},{1},{2}", id, i, species_index); + csv.WriteLine(line); + id++; + } + } + } + /// /// Writes a new video with bounding boxes around detections. /// @@ -199,8 +249,8 @@ static void WriteVideo( string out_path, Rect roi, VectorDouble transform, - VectorVectorRect detections, - VectorVectorVectorFloat scores) { + VectorVectorDetection detections, + VectorVectorClassification scores) { // Initialize the video reader. VideoReader reader = new VideoReader(); @@ -233,8 +283,8 @@ static void WriteVideo( frame.DrawRect(roi, blue, 1, transform); for (int j = 0; j < detections[i].Count; ++j) { Color det_color = red; - double clear = scores[i][j][2]; - double hand = scores[i][j][1]; + double clear = scores[i][j].cover[2]; + double hand = scores[i][j].cover[1]; if (j == 0) { if (clear > hand) { frame.DrawText("Clear", new PairIntInt(0, 0), green); @@ -244,7 +294,7 @@ static void WriteVideo( det_color = red; } } - frame.DrawRect(detections[i][j], det_color, 2, transform, roi); + frame.DrawRect(detections[i][j].location, det_color, 2, transform, roi); } status = writer.AddFrame(frame); if (status != ErrorCode.kSuccess) { @@ -259,15 +309,16 @@ static void WriteVideo( static void Main(string[] args) { // Check input arguments. 
- if (args.Length < 4) { + if (args.Length < 5) { Console.WriteLine("Expected at least four arguments: "); Console.WriteLine(" Path to pb file with find_ruler model."); Console.WriteLine(" Path to pb file with detect model."); Console.WriteLine(" Path to pb file with classify model."); + Console.WriteLine(" Path to pb file with count model."); Console.WriteLine(" Path to one or more video files."); } - for (int vid_idx = 3; vid_idx < args.Length; ++vid_idx) { + for (int vid_idx = 4; vid_idx < args.Length; ++vid_idx) { // Find the roi. Console.WriteLine("Finding region of interest..."); Rect roi; @@ -276,8 +327,8 @@ static void Main(string[] args) { // Find detections and classify them. Console.WriteLine("Performing detection and classification..."); - VectorVectorRect detections; - VectorVectorVectorFloat scores; + VectorVectorDetection detections; + VectorVectorClassification scores; DetectAndClassify( args[1], args[2], @@ -287,11 +338,20 @@ static void Main(string[] args) { out detections, out scores); + // Count fish and write to csv file. + Console.WriteLine("Counting individuals from sequences..."); + WriteCounts( + args[3], + String.Format("fish_counts_{0}.csv", vid_idx - 4), + roi, + detections, + scores); + // Write annotated video to file. Console.WriteLine("Writing video to file..."); WriteVideo( args[vid_idx], - String.Format("annotated_video_{0}.avi", vid_idx - 3), + String.Format("annotated_video_{0}.avi", vid_idx - 4), roi, transform, detections, diff --git a/examples/deploy/python/classify.py b/examples/deploy/python/classify.py index ac89f26..5f21014 100644 --- a/examples/deploy/python/classify.py +++ b/examples/deploy/python/classify.py @@ -55,7 +55,7 @@ raise RuntimeError("Failed to add image for processing!") # Process the loaded images. 
- scores = openem.VectorVectorFloat() + scores = openem.VectorClassification() status = classifier.Process(scores) if not status == openem.kSuccess: raise RuntimeError("Failed to process images!") @@ -64,18 +64,18 @@ for img, s in zip(imgs, scores): print("*******************************************") print("Fish cover scores:") - print("No fish: {}".format(s[0])) - print("Hand over fish: {}".format(s[1])) - print("Fish clear: {}".format(s[2])) + print("No fish: {}".format(s.cover[0])) + print("Hand over fish: {}".format(s.cover[1])) + print("Fish clear: {}".format(s.cover[2])) print("*******************************************") print("Fish species scores:") - print("Fourspot: {}".format(s[3])) - print("Grey sole: {}".format(s[4])) - print("Other: {}".format(s[5])) - print("Plaice: {}".format(s[6])) - print("Summer: {}".format(s[7])) - print("Windowpane: {}".format(s[8])) - print("Winter: {}".format(s[9])) + print("Fourspot: {}".format(s.species[0])) + print("Grey sole: {}".format(s.species[1])) + print("Other: {}".format(s.species[2])) + print("Plaice: {}".format(s.species[3])) + print("Summer: {}".format(s.species[4])) + print("Windowpane: {}".format(s.species[5])) + print("Winter: {}".format(s.species[6])) print("") img.Show() diff --git a/examples/deploy/python/detect.py b/examples/deploy/python/detect.py index 5962c33..efc7f19 100644 --- a/examples/deploy/python/detect.py +++ b/examples/deploy/python/detect.py @@ -53,7 +53,7 @@ raise RuntimeError("Failed to add image for processing!") # Process the loaded images. - detections = openem.VectorVectorRect() + detections = openem.VectorVectorDetection() status = detector.Process(detections) if not status == openem.kSuccess: raise RuntimeError("Failed to process images!") @@ -61,6 +61,6 @@ # Display the detections on the image. 
for dets, img in zip(detections, imgs): for det in dets: - img.DrawRect(det) + img.DrawRect(det.location) img.Show() diff --git a/examples/deploy/python/video.py b/examples/deploy/python/video.py index f2bbe53..af063fd 100644 --- a/examples/deploy/python/video.py +++ b/examples/deploy/python/video.py @@ -129,7 +129,7 @@ def detect_and_classify(detect_path, classify_path, vid_path, roi, transform): while True: # Find detections. - dets = openem.VectorVectorRect() + dets = openem.VectorVectorDetection() imgs = [openem.Image() for _ in range(max_img)] for i, img in enumerate(imgs): status = reader.GetFrame(img) @@ -149,9 +149,9 @@ def detect_and_classify(detect_path, classify_path, vid_path, roi, transform): # Classify detections for det_frame, img in zip(dets, imgs): - score_batch = openem.VectorVectorFloat() + score_batch = openem.VectorClassification() for det in det_frame: - det_img = openem.GetDetImage(img, det) + det_img = openem.GetDetImage(img, det.location) status = classifier.AddImage(det_img) if not status == openem.kSuccess: raise RuntimeError("Failed to add frame to classifier!") @@ -163,6 +163,43 @@ def detect_and_classify(detect_path, classify_path, vid_path, roi, transform): break return (detections, scores) +def write_counts(count_path, out_path, roi, detections, scores): + """Writes a csv file containing fish species and frame numbers. + + # Arguments + count_path: Path to count model file. + out_path: Path to output csv file. + roi: Region of interest, needed for image width and height. + detections: Detections for each frame. + scores: Cover and species scores for each detection. + """ + # Create and initialize keyframe finder. + finder = openem.KeyframeFinder() + status = finder.Init(count_path, roi[2], roi[3]) + if not status == openem.kSuccess: + raise IOError("Failed to initialize keyframe finder!") + + # Process keyframe finder. 
+ keyframes = openem.VectorInt() + status = finder.Process(scores, detections, keyframes) + if not status == openem.kSuccess: + raise RuntimeError("Failed to process keyframe finder!") + + # Write the keyframes out. + with open(out_path, "w") as csv: + csv.write("id,frame,species_index\n") + uid = 0 + for i in keyframes: + c = scores[i][0] + max_score = 0.0 + species_index = 0 + for j, s in enumerate(c.species): + if s > max_score: + max_score = s + species_index = j + csv.write("{},{},{}\n".format(uid, i, species_index)) + uid += 1 + def write_video(vid_path, out_path, roi, transform, detections, scores): """Writes a new video with bounding boxes around detections. @@ -198,8 +235,8 @@ def write_video(vid_path, out_path, roi, transform, detections, scores): raise RuntimeError("Error retrieving video frame!") frame.DrawRect(roi, (255, 0, 0), 1, transform) for j, (det, score) in enumerate(zip(det_frame, score_frame)): - clear = score[2] - hand = score[1] + clear = score.cover[2] + hand = score.cover[1] if j == 0: if clear > hand: frame.DrawText("Clear", (0, 0), (0, 255, 0)) @@ -207,7 +244,7 @@ def write_video(vid_path, out_path, roi, transform, detections, scores): else: frame.DrawText("Hand", (0, 0), (0, 0, 255)) det_color = (0, 0, 255) - frame.DrawRect(det, det_color, 2, transform, roi) + frame.DrawRect(det.location, det_color, 2, transform, roi) status = writer.AddFrame(frame) if not status == openem.kSuccess: raise RuntimeError("Error adding frame to video!") @@ -224,6 +261,9 @@ def write_video(vid_path, out_path, roi, transform, detections, scores): parser.add_argument("classify_model", type=str, help="Path to pb file with classify model.") + parser.add_argument("count_model", + type=str, + help="Path to pb file with count model.") parser.add_argument("video_paths", type=str, nargs="+", @@ -243,6 +283,15 @@ def write_video(vid_path, out_path, roi, transform, detections, scores): roi, transform) + # Write counts to csv. 
+ print("Writing counts to csv...") + write_counts( + args.count_model, + "fish_counts_{}.csv".format(i), + roi, + detections, + scores) + # Write annotated video to file. print("Writing video to file...") write_video( diff --git a/examples/deploy/run_all.py b/examples/deploy/run_all.py index fbaf4ac..b61cbfb 100644 --- a/examples/deploy/run_all.py +++ b/examples/deploy/run_all.py @@ -13,7 +13,7 @@ def find_model(base_path, ex): def run_example(lang, ex, base_path): exe_path = os.path.join(os.getcwd(), lang, ex + EXE_EXTENSIONS[lang]) if ex == "video": - models = ["find_ruler", "detect", "classify"] + models = ["find_ruler", "detect", "classify", "count"] model_paths = [find_model(base_path, m) for m in models] inputs = glob.glob(os.path.join(base_path, "deploy", ex, "*.mp4")) else: