diff --git a/ThreadPool.h b/ThreadPool.h index 4183203..52b4696 100644 --- a/ThreadPool.h +++ b/ThreadPool.h @@ -1,3 +1,7 @@ +// +// Created: https://github.com/progschj/ThreadPool +// Modified by: Martin Opat +// #ifndef THREAD_POOL_H #define THREAD_POOL_H @@ -15,60 +19,68 @@ class ThreadPool { public: ThreadPool(size_t); template - auto enqueue(F&& f, Args&&... args) - -> std::future::type>; + auto enqueue(F&& f, Args&&... args) + -> std::future::type>; ~ThreadPool(); + void waitForAll(); private: // need to keep track of threads so we can join them std::vector< std::thread > workers; // the task queue std::queue< std::function > tasks; - + // synchronization std::mutex queue_mutex; std::condition_variable condition; bool stop; + + std::atomic tasks_count{0}; + std::condition_variable all_tasks_done; }; - + // the constructor just launches some amount of workers inline ThreadPool::ThreadPool(size_t threads) - : stop(false) + : stop(false) { for(size_t i = 0;i task; - + for(;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait(lock, + [this]{ return this->stop || !this->tasks.empty(); }); + if(this->stop && this->tasks.empty()) + return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); std::unique_lock lock(this->queue_mutex); - this->condition.wait(lock, - [this]{ return this->stop || !this->tasks.empty(); }); - if(this->stop && this->tasks.empty()) - return; - task = std::move(this->tasks.front()); - this->tasks.pop(); + if (--tasks_count == 0) { + all_tasks_done.notify_all(); + } } - - task(); } - } ); } // add new work item to the pool template -auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> +auto ThreadPool::enqueue(F&& f, Args&&... args) +-> std::future::type> { using return_type = typename std::result_of::type; auto task = std::make_shared< std::packaged_task >( std::bind(std::forward(f), std::forward(args)...) - ); - + ); + std::future res = task->get_future(); { std::unique_lock lock(queue_mutex); @@ -78,6 +90,7 @@ auto ThreadPool::enqueue(F&& f, Args&&... args) throw std::runtime_error("enqueue on stopped ThreadPool"); tasks.emplace([task](){ (*task)(); }); + ++tasks_count; } condition.notify_one(); return res; @@ -95,4 +108,9 @@ inline ThreadPool::~ThreadPool() worker.join(); } +inline void ThreadPool::waitForAll() { + std::unique_lock lock(queue_mutex); + all_tasks_done.wait(lock, [this](){ return tasks_count == 0 && tasks.empty(); }); +} + #endif