From 5499b892e656810889fcf79284e32e8a89211c0a Mon Sep 17 00:00:00 2001
From: Sergey Kuznetsov
Date: Wed, 14 Aug 2024 12:00:13 +0100
Subject: [PATCH] feat: Add stop to WorkQueue (#1600)

For #442.
---
 src/rpc/WorkQueue.cpp             | 69 ++++++++++++++++++++++++++++
 src/rpc/WorkQueue.hpp             | 71 ++++++++++++++++++++---------
 tests/unit/rpc/WorkQueueTests.cpp | 76 +++++++++++++++++++++++--------
 3 files changed, 174 insertions(+), 42 deletions(-)

diff --git a/src/rpc/WorkQueue.cpp b/src/rpc/WorkQueue.cpp
index 7d0df6eec..305c35b69 100644
--- a/src/rpc/WorkQueue.cpp
+++ b/src/rpc/WorkQueue.cpp
@@ -19,13 +19,40 @@
 
 #include "rpc/WorkQueue.hpp"
 
+#include "util/config/Config.hpp"
+#include "util/log/Logger.hpp"
 #include "util/prometheus/Label.hpp"
 #include "util/prometheus/Prometheus.hpp"
 
+#include <boost/json/object.hpp>
+
+#include <cstddef>
 #include <cstdint>
+#include <functional>
+#include <thread>
+#include <utility>
 
 namespace rpc {
 
+void
+WorkQueue::OneTimeCallable::setCallable(std::function<void()> func)
+{
+    func_ = func;
+}
+
+void
+WorkQueue::OneTimeCallable::operator()()
+{
+    if (not called_) {
+        func_();
+        called_ = true;
+    }
+}
+WorkQueue::OneTimeCallable::operator bool() const
+{
+    return func_.operator bool();
+}
+
 WorkQueue::WorkQueue(std::uint32_t numWorkers, uint32_t maxSize)
     : queued_{PrometheusService::counterInt(
           "work_queue_queued_total_number",
@@ -53,10 +80,52 @@ WorkQueue::~WorkQueue()
     join();
 }
 
+void
+WorkQueue::stop(std::function<void()> onQueueEmpty)
+{
+    auto handler = onQueueEmpty_.lock();
+    handler->setCallable(std::move(onQueueEmpty));
+    stopping_ = true;
+    if (size() == 0) {
+        handler->operator()();
+    }
+}
+
+WorkQueue
+WorkQueue::make_WorkQueue(util::Config const& config)
+{
+    static util::Logger const log{"RPC"};
+    auto const serverConfig = config.section("server");
+    auto const numThreads = config.valueOr<uint32_t>("workers", std::thread::hardware_concurrency());
+    auto const maxQueueSize = serverConfig.valueOr<uint32_t>("max_queue_size", 0);  // 0 is no limit
+
+    LOG(log.info()) << "Number of workers = " << numThreads << ". Max queue size = " << maxQueueSize;
+    return WorkQueue{numThreads, maxQueueSize};
+}
+
+boost::json::object
+WorkQueue::report() const
+{
+    auto obj = boost::json::object{};
+
+    obj["queued"] = queued_.get().value();
+    obj["queued_duration_us"] = durationUs_.get().value();
+    obj["current_queue_size"] = curSize_.get().value();
+    obj["max_queue_size"] = maxSize_;
+
+    return obj;
+}
+
 void
 WorkQueue::join()
 {
     ioc_.join();
 }
 
+size_t
+WorkQueue::size() const
+{
+    return curSize_.get().value();
+}
+
 } // namespace rpc
diff --git a/src/rpc/WorkQueue.hpp b/src/rpc/WorkQueue.hpp
index 70c6b7153..1f540cba0 100644
--- a/src/rpc/WorkQueue.hpp
+++ b/src/rpc/WorkQueue.hpp
@@ -19,6 +19,8 @@
 
 #pragma once
 
+#include "util/Assert.hpp"
+#include "util/Mutex.hpp"
 #include "util/config/Config.hpp"
 #include "util/log/Logger.hpp"
 #include "util/prometheus/Counter.hpp"
@@ -30,11 +32,12 @@
 #include <boost/asio/thread_pool.hpp>
 #include <boost/json/object.hpp>
 
+#include <atomic>
 #include <chrono>
+#include <cstddef>
 #include <cstdint>
 #include <functional>
 #include <memory>
-#include <thread>
 
 namespace rpc {
 
@@ -52,6 +55,23 @@ class WorkQueue {
     util::Logger log_{"RPC"};
     boost::asio::thread_pool ioc_;
+    std::atomic_bool stopping_;
+
+    class OneTimeCallable {
+        std::function<void()> func_;
+        bool called_{false};
+
+    public:
+        void
+        setCallable(std::function<void()> func);
+
+        void
+        operator()();
+
+        operator bool() const;
+    };
+
+    util::Mutex<OneTimeCallable> onQueueEmpty_;
 
 public:
     /**
      * @brief Create an instance of the work queue.
@@ -62,23 +82,22 @@ class WorkQueue {
     WorkQueue(std::uint32_t numWorkers, uint32_t maxSize = 0);
     ~WorkQueue();
 
+    /**
+     * @brief Put the work queue into a stopping state. This will prevent new jobs from being queued.
+     *
+     * @param onQueueEmpty A callback to run when the last task in the queue is completed
+     */
+    void
+    stop(std::function<void()> onQueueEmpty);
+
     /**
      * @brief A factory function that creates the work queue based on a config.
      *
      * @param config The Clio config to use
      * @return The work queue
      */
     static WorkQueue
-    make_WorkQueue(util::Config const& config)
-    {
-        static util::Logger const log{"RPC"};
-        auto const serverConfig = config.section("server");
-        auto const numThreads = config.valueOr<uint32_t>("workers", std::thread::hardware_concurrency());
-        auto const maxQueueSize = serverConfig.valueOr<uint32_t>("max_queue_size", 0);  // 0 is no limit
-
-        LOG(log.info()) << "Number of workers = " << numThreads << ". Max queue size = " << maxQueueSize;
-        return WorkQueue{numThreads, maxQueueSize};
-    }
+    make_WorkQueue(util::Config const& config);
 
     /**
      * @brief Submit a job to the work queue.
@@ -91,6 +110,11 @@ class WorkQueue {
     bool
     postCoro(FnType&& func, bool isWhiteListed)
     {
+        if (stopping_) {
+            LOG(log_.warn()) << "Queue is stopping, rejecting incoming task.";
+            return false;
+        }
+
         if (curSize_.get().value() >= maxSize_ && !isWhiteListed) {
             LOG(log_.warn()) << "Queue is full. rejecting job. current size = " << curSize_.get().value()
                              << "; max size = " << maxSize_;
@@ -114,6 +138,11 @@ class WorkQueue {
                 LOG(log_.info()) << "WorkQueue wait time = " << wait.count() << " queue size = " << curSize_.get();
                 func(yield);
                 --curSize_.get();
+                if (curSize_.get().value() == 0 && stopping_) {
+                    auto onTasksComplete = onQueueEmpty_.lock();
+                    ASSERT(onTasksComplete->operator bool(), "onTasksComplete must be set when stopping is true.");
+                    onTasksComplete->operator()();
+                }
             }
         );
 
@@ -126,23 +155,21 @@ class WorkQueue {
      * @return The report as a JSON object.
      */
     boost::json::object
-    report() const
-    {
-        auto obj = boost::json::object{};
-
-        obj["queued"] = queued_.get().value();
-        obj["queued_duration_us"] = durationUs_.get().value();
-        obj["current_queue_size"] = curSize_.get().value();
-        obj["max_queue_size"] = maxSize_;
-
-        return obj;
-    }
+    report() const;
 
     /**
      * @brief Wait until all the jobs in the queue are finished.
      */
     void
     join();
+
+    /**
+     * @brief Get the size of the queue.
+     *
+     * @return The number of jobs in the queue.
+     */
+    size_t
+    size() const;
 };
 
 } // namespace rpc
diff --git a/tests/unit/rpc/WorkQueueTests.cpp b/tests/unit/rpc/WorkQueueTests.cpp
index 637234f73..c976907f5 100644
--- a/tests/unit/rpc/WorkQueueTests.cpp
+++ b/tests/unit/rpc/WorkQueueTests.cpp
@@ -31,6 +31,7 @@
 #include <condition_variable>
 #include <cstdint>
 #include <mutex>
+#include <semaphore>
 
 using namespace util;
 using namespace rpc;
@@ -43,14 +44,14 @@ constexpr auto JSONConfig = R"JSON({
 })JSON";
 } // namespace
 
-struct RPCWorkQueueTestBase : NoLoggerFixture {
+struct WorkQueueTestBase : NoLoggerFixture {
     Config cfg = Config{boost::json::parse(JSONConfig)};
     WorkQueue queue = WorkQueue::make_WorkQueue(cfg);
 };
 
-struct RPCWorkQueueTest : WithPrometheus, RPCWorkQueueTestBase {};
+struct WorkQueueTest : WithPrometheus, WorkQueueTestBase {};
 
-TEST_F(RPCWorkQueueTest, WhitelistedExecutionCountAddsUp)
+TEST_F(WorkQueueTest, WhitelistedExecutionCountAddsUp)
 {
     auto constexpr static TOTAL = 512u;
     uint32_t executeCount = 0u;
@@ -77,7 +78,7 @@ TEST_F(RPCWorkQueueTest, WhitelistedExecutionCountAddsUp)
     EXPECT_EQ(report.at("max_queue_size"), 2);
 }
 
-TEST_F(RPCWorkQueueTest, NonWhitelistedPreventSchedulingAtQueueLimitExceeded)
+TEST_F(WorkQueueTest, NonWhitelistedPreventSchedulingAtQueueLimitExceeded)
 {
     auto constexpr static TOTAL = 3u;
     auto expectedCount = 2u;
@@ -112,35 +113,70 @@ TEST_F(RPCWorkQueueTest, NonWhitelistedPreventSchedulingAtQueueLimitExceeded)
     EXPECT_TRUE(unblocked);
 }
 
-struct RPCWorkQueueMockPrometheusTest : WithMockPrometheus, RPCWorkQueueTestBase {};
+struct WorkQueueStopTest : WorkQueueTest {
+    testing::StrictMock<testing::MockFunction<void()>> onTasksComplete;
+    testing::StrictMock<testing::MockFunction<void()>> taskMock;
+};
+
+TEST_F(WorkQueueStopTest, RejectsNewTasksWhenStopping)
+{
+    EXPECT_CALL(taskMock, Call());
+    EXPECT_TRUE(queue.postCoro([this](auto /* yield */) { taskMock.Call(); }, false));
+
+    queue.stop([]() {});
+    EXPECT_FALSE(queue.postCoro([this](auto /* yield */) { taskMock.Call(); }, false));
+
+    queue.join();
+}
+
+TEST_F(WorkQueueStopTest, CallsOnTasksCompleteWhenStoppingAndQueueIsEmpty)
+{
+    EXPECT_CALL(taskMock, Call());
+    EXPECT_TRUE(queue.postCoro([this](auto /* yield */) { taskMock.Call(); }, false));
+
+    EXPECT_CALL(onTasksComplete, Call()).WillOnce([&]() { EXPECT_EQ(queue.size(), 0u); });
+    queue.stop(onTasksComplete.AsStdFunction());
+    queue.join();
+}
+TEST_F(WorkQueueStopTest, CallsOnTasksCompleteWhenStoppingOnLastTask)
+{
+    std::binary_semaphore semaphore{0};
+
+    EXPECT_CALL(taskMock, Call());
+    EXPECT_TRUE(queue.postCoro(
+        [&](auto /* yield */) {
+            taskMock.Call();
+            semaphore.acquire();
+        },
+        false
+    ));
+
+    EXPECT_CALL(onTasksComplete, Call()).WillOnce([&]() { EXPECT_EQ(queue.size(), 0u); });
+    queue.stop(onTasksComplete.AsStdFunction());
+    semaphore.release();
+
+    queue.join();
+}
+
+struct WorkQueueMockPrometheusTest : WithMockPrometheus, WorkQueueTestBase {};
 
-TEST_F(RPCWorkQueueMockPrometheusTest, postCoroCouhters)
+TEST_F(WorkQueueMockPrometheusTest, postCoroCouhters)
 {
     auto& queuedMock = makeMock<CounterInt>("work_queue_queued_total_number", "");
     auto& durationMock = makeMock<CounterInt>("work_queue_cumulitive_tasks_duration_us", "");
     auto& curSizeMock = makeMock<GaugeInt>("work_queue_current_size", "");
 
-    std::mutex mtx;
-    bool canContinue = false;
-    std::condition_variable cv;
+    std::binary_semaphore semaphore{0};
 
-    EXPECT_CALL(curSizeMock, value()).WillOnce(::testing::Return(0));
+    EXPECT_CALL(curSizeMock, value()).Times(2).WillRepeatedly(::testing::Return(0));
     EXPECT_CALL(curSizeMock, add(1));
     EXPECT_CALL(queuedMock, add(1));
     EXPECT_CALL(durationMock, add(::testing::Gt(0))).WillOnce([&](auto) {
         EXPECT_CALL(curSizeMock, add(-1));
-        std::unique_lock const lk{mtx};
-        canContinue = true;
-        cv.notify_all();
+        semaphore.release();
     });
 
-    auto const res = queue.postCoro(
-        [&](auto /* yield */) {
-            std::unique_lock lk{mtx};
-            cv.wait(lk, [&]() { return canContinue; });
-        },
-        false
-    );
+    auto const res = queue.postCoro([&](auto /* yield */) { semaphore.acquire(); }, false);
 
     ASSERT_TRUE(res);
     queue.join();
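A minimal usage sketch of the stop/postCoro flow this patch introduces. It assumes a util::Config parsed from JSON (as in the tests above); the config contents and the printed messages are illustrative only, everything else follows the signatures added by the patch.

// Sketch: build a WorkQueue from config, submit a coroutine job, then drain the queue on shutdown.
#include "rpc/WorkQueue.hpp"
#include "util/config/Config.hpp"

#include <boost/json/parse.hpp>

#include <iostream>

int
main()
{
    auto const cfg = util::Config{boost::json::parse(R"({"server": {"max_queue_size": 2}})")};
    auto queue = rpc::WorkQueue::make_WorkQueue(cfg);

    // Non-whitelisted job; postCoro returns false if the queue is full or already stopping.
    queue.postCoro([](auto /* yield */) { std::cout << "job executed\n"; }, false);

    // Reject any further jobs and run the callback once the last queued job completes.
    queue.stop([]() { std::cout << "work queue drained\n"; });

    queue.join();
}

Per the added implementation, stop() invokes the callback immediately when the queue is already empty and otherwise from the worker that finishes the last job, which is the behaviour exercised by the WorkQueueStopTest cases.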