diff --git a/README.md b/README.md index 4340fa33a0..61a1f3f103 100644 --- a/README.md +++ b/README.md @@ -433,6 +433,17 @@ If you want to set the thread count at runtime, there is no convenient way... Bu svr.new_task_queue = [] { return new ThreadPool(12); }; ``` +You can also provide an optional parameter to limit the maximum number +of pending requests, i.e. requests `accept()`ed by the listener but +still waiting to be serviced by worker threads. + +```cpp +svr.new_task_queue = [] { return new ThreadPool(/*num_threads=*/12, /*max_queued_requests=*/18); }; +``` + +Default limit is 0 (unlimited). Once the limit is reached, the listener +will shutdown the client connection. + ### Override the default thread pool with yours You can supply your own thread pool implementation according to your need. @@ -444,8 +455,10 @@ public: pool_.start_with_thread_count(n); } - virtual void enqueue(std::function fn) override { - pool_.enqueue(fn); + virtual bool enqueue(std::function fn) override { + /* Return true if the task was actually enqueued, or false + * if the caller must drop the corresponding connection. */ + return pool_.enqueue(fn); } virtual void shutdown() override { diff --git a/httplib.h b/httplib.h index 5dfaeadec8..7cc6ed2577 100644 --- a/httplib.h +++ b/httplib.h @@ -582,7 +582,7 @@ class TaskQueue { TaskQueue() = default; virtual ~TaskQueue() = default; - virtual void enqueue(std::function fn) = 0; + virtual bool enqueue(std::function fn) = 0; virtual void shutdown() = 0; virtual void on_idle() {} @@ -590,7 +590,8 @@ class TaskQueue { class ThreadPool : public TaskQueue { public: - explicit ThreadPool(size_t n) : shutdown_(false) { + explicit ThreadPool(size_t n, size_t mqr = 0) + : shutdown_(false), max_queued_requests_(mqr) { while (n) { threads_.emplace_back(worker(*this)); n--; @@ -600,13 +601,17 @@ class ThreadPool : public TaskQueue { ThreadPool(const ThreadPool &) = delete; ~ThreadPool() override = default; - void enqueue(std::function fn) override { + bool enqueue(std::function fn) override { { std::unique_lock lock(mutex_); + if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { + return false; + } jobs_.push_back(std::move(fn)); } cond_.notify_one(); + return true; } void shutdown() override { @@ -656,6 +661,7 @@ class ThreadPool : public TaskQueue { std::list> jobs_; bool shutdown_; + size_t max_queued_requests_ = 0; std::condition_variable cond_; std::mutex mutex_; @@ -6242,7 +6248,11 @@ inline bool Server::listen_internal() { #endif } - task_queue->enqueue([this, sock]() { process_and_close_socket(sock); }); + if (!task_queue->enqueue( + [this, sock]() { process_and_close_socket(sock); })) { + detail::shutdown_socket(sock); + detail::close_socket(sock); + } } task_queue->shutdown(); diff --git a/test/test.cc b/test/test.cc index 91513cd936..002debb9e1 100644 --- a/test/test.cc +++ b/test/test.cc @@ -6506,18 +6506,103 @@ TEST(SocketStream, is_writable_INET) { #endif // #ifndef _WIN32 TEST(TaskQueueTest, IncreaseAtomicInteger) { - static constexpr unsigned int number_of_task{1000000}; + static constexpr unsigned int number_of_tasks{1000000}; std::atomic_uint count{0}; std::unique_ptr task_queue{ new ThreadPool{CPPHTTPLIB_THREAD_POOL_COUNT}}; - for (unsigned int i = 0; i < number_of_task; ++i) { - task_queue->enqueue( + for (unsigned int i = 0; i < number_of_tasks; ++i) { + auto queued = task_queue->enqueue( [&count] { count.fetch_add(1, std::memory_order_relaxed); }); + EXPECT_TRUE(queued); + } + + EXPECT_NO_THROW(task_queue->shutdown()); + EXPECT_EQ(number_of_tasks, count.load()); +} + +TEST(TaskQueueTest, IncreaseAtomicIntegerWithQueueLimit) { + static constexpr unsigned int number_of_tasks{1000000}; + static constexpr unsigned int qlimit{2}; + unsigned int queued_count{0}; + std::atomic_uint count{0}; + std::unique_ptr task_queue{ + new ThreadPool{/*num_threads=*/1, qlimit}}; + + for (unsigned int i = 0; i < number_of_tasks; ++i) { + if (task_queue->enqueue( + [&count] { count.fetch_add(1, std::memory_order_relaxed); })) { + queued_count++; + } + } + + EXPECT_NO_THROW(task_queue->shutdown()); + EXPECT_EQ(queued_count, count.load()); + EXPECT_TRUE(queued_count <= number_of_tasks); + EXPECT_TRUE(queued_count >= qlimit); +} + +TEST(TaskQueueTest, MaxQueuedRequests) { + static constexpr unsigned int qlimit{3}; + std::unique_ptr task_queue{new ThreadPool{1, qlimit}}; + std::condition_variable sem_cv; + std::mutex sem_mtx; + int credits = 0; + bool queued; + + /* Fill up the queue with tasks that will block until we give them credits to + * complete. */ + for (unsigned int n = 0; n <= qlimit;) { + queued = task_queue->enqueue([&sem_mtx, &sem_cv, &credits] { + std::unique_lock lock(sem_mtx); + while (credits <= 0) { + sem_cv.wait(lock); + } + /* Consume the credit and signal the test code if they are all gone. */ + if (--credits == 0) { sem_cv.notify_one(); } + }); + + if (n < qlimit) { + /* The first qlimit enqueues must succeed. */ + EXPECT_TRUE(queued); + } else { + /* The last one will succeed only when the worker thread + * starts and dequeues the first blocking task. Although + * not necessary for the correctness of this test, we sleep for + * a short while to avoid busy waiting. */ + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + if (queued) { n++; } + } + + /* Further enqueues must fail since the queue is full. */ + for (auto i = 0; i < 4; i++) { + queued = task_queue->enqueue([] {}); + EXPECT_FALSE(queued); + } + + /* Give the credits to allow the previous tasks to complete. */ + { + std::unique_lock lock(sem_mtx); + credits += qlimit + 1; + } + sem_cv.notify_all(); + + /* Wait for all the credits to be consumed. */ + { + std::unique_lock lock(sem_mtx); + while (credits > 0) { + sem_cv.wait(lock); + } + } + + /* Check that we are able again to enqueue at least qlimit tasks. */ + for (unsigned int i = 0; i < qlimit; i++) { + queued = task_queue->enqueue([] {}); + EXPECT_TRUE(queued); } EXPECT_NO_THROW(task_queue->shutdown()); - EXPECT_EQ(number_of_task, count.load()); } TEST(RedirectTest, RedirectToUrlWithQueryParameters) {