Skip to content

Commit

Permalink
feat: Add stop to WorkQueue (#1600)
Browse files Browse the repository at this point in the history
For #442.
  • Loading branch information
kuznetsss authored Aug 14, 2024
1 parent 0ff1eda commit 5499b89
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 42 deletions.
69 changes: 69 additions & 0 deletions src/rpc/WorkQueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,40 @@

#include "rpc/WorkQueue.hpp"

#include "util/config/Config.hpp"
#include "util/log/Logger.hpp"
#include "util/prometheus/Label.hpp"
#include "util/prometheus/Prometheus.hpp"

#include <boost/json/object.hpp>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <thread>
#include <utility>

namespace rpc {

void
WorkQueue::OneTimeCallable::setCallable(std::function<void()> func)
{
func_ = func;
}

void
WorkQueue::OneTimeCallable::operator()()
{
if (not called_) {
func_();
called_ = true;
}
}
WorkQueue::OneTimeCallable::operator bool() const
{
return func_.operator bool();
}

WorkQueue::WorkQueue(std::uint32_t numWorkers, uint32_t maxSize)
: queued_{PrometheusService::counterInt(
"work_queue_queued_total_number",
Expand Down Expand Up @@ -53,10 +80,52 @@ WorkQueue::~WorkQueue()
join();
}

void
WorkQueue::stop(std::function<void()> onQueueEmpty)
{
auto handler = onQueueEmpty_.lock();
handler->setCallable(std::move(onQueueEmpty));
stopping_ = true;
if (size() == 0) {
handler->operator()();
}
}

WorkQueue
WorkQueue::make_WorkQueue(util::Config const& config)
{
static util::Logger const log{"RPC"};
auto const serverConfig = config.section("server");
auto const numThreads = config.valueOr<uint32_t>("workers", std::thread::hardware_concurrency());
auto const maxQueueSize = serverConfig.valueOr<uint32_t>("max_queue_size", 0); // 0 is no limit

LOG(log.info()) << "Number of workers = " << numThreads << ". Max queue size = " << maxQueueSize;
return WorkQueue{numThreads, maxQueueSize};
}

boost::json::object
WorkQueue::report() const
{
auto obj = boost::json::object{};

obj["queued"] = queued_.get().value();
obj["queued_duration_us"] = durationUs_.get().value();
obj["current_queue_size"] = curSize_.get().value();
obj["max_queue_size"] = maxSize_;

return obj;
}

void
WorkQueue::join()
{
ioc_.join();
}

size_t
WorkQueue::size() const
{
return curSize_.get().value();
}

} // namespace rpc
71 changes: 49 additions & 22 deletions src/rpc/WorkQueue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

#pragma once

#include "util/Assert.hpp"
#include "util/Mutex.hpp"
#include "util/config/Config.hpp"
#include "util/log/Logger.hpp"
#include "util/prometheus/Counter.hpp"
Expand All @@ -30,11 +32,12 @@
#include <boost/json.hpp>
#include <boost/json/object.hpp>

#include <atomic>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <thread>

namespace rpc {

Expand All @@ -52,6 +55,23 @@ class WorkQueue {
util::Logger log_{"RPC"};
boost::asio::thread_pool ioc_;

std::atomic_bool stopping_;

class OneTimeCallable {
std::function<void()> func_;
bool called_{false};

public:
void
setCallable(std::function<void()> func);

void
operator()();

operator bool() const;
};
util::Mutex<OneTimeCallable> onQueueEmpty_;

public:
/**
* @brief Create an we instance of the work queue.
Expand All @@ -62,23 +82,22 @@ class WorkQueue {
WorkQueue(std::uint32_t numWorkers, uint32_t maxSize = 0);
~WorkQueue();

/**
* @brief Put the work queue into a stopping state. This will prevent new jobs from being queued.
*
* @param onQueueEmpty A callback to run when the last task in the queue is completed
*/
void
stop(std::function<void()> onQueueEmpty);

/**
* @brief A factory function that creates the work queue based on a config.
*
* @param config The Clio config to use
* @return The work queue
*/
static WorkQueue
make_WorkQueue(util::Config const& config)
{
static util::Logger const log{"RPC"};
auto const serverConfig = config.section("server");
auto const numThreads = config.valueOr<uint32_t>("workers", std::thread::hardware_concurrency());
auto const maxQueueSize = serverConfig.valueOr<uint32_t>("max_queue_size", 0); // 0 is no limit

LOG(log.info()) << "Number of workers = " << numThreads << ". Max queue size = " << maxQueueSize;
return WorkQueue{numThreads, maxQueueSize};
}
make_WorkQueue(util::Config const& config);

/**
* @brief Submit a job to the work queue.
Expand All @@ -94,6 +113,11 @@ class WorkQueue {
bool
postCoro(FnType&& func, bool isWhiteListed)
{
if (stopping_) {
LOG(log_.warn()) << "Queue is stopping, rejecting incoming task.";
return false;
}

if (curSize_.get().value() >= maxSize_ && !isWhiteListed) {
LOG(log_.warn()) << "Queue is full. rejecting job. current size = " << curSize_.get().value()
<< "; max size = " << maxSize_;
Expand All @@ -116,6 +140,11 @@ class WorkQueue {

func(yield);
--curSize_.get();
if (curSize_.get().value() == 0 && stopping_) {
auto onTasksComplete = onQueueEmpty_.lock();
ASSERT(onTasksComplete->operator bool(), "onTasksComplete must be set when stopping is true.");
onTasksComplete->operator()();
}
}
);

Expand All @@ -128,23 +157,21 @@ class WorkQueue {
* @return The report as a JSON object.
*/
boost::json::object
report() const
{
auto obj = boost::json::object{};

obj["queued"] = queued_.get().value();
obj["queued_duration_us"] = durationUs_.get().value();
obj["current_queue_size"] = curSize_.get().value();
obj["max_queue_size"] = maxSize_;

return obj;
}
report() const;

/**
* @brief Wait until all the jobs in the queue are finished.
*/
void
join();

/**
* @brief Get the size of the queue.
*
* @return The numver of jobs in the queue.
*/
size_t
size() const;
};

} // namespace rpc
76 changes: 56 additions & 20 deletions tests/unit/rpc/WorkQueueTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <semaphore>

using namespace util;
using namespace rpc;
Expand All @@ -43,14 +44,14 @@ constexpr auto JSONConfig = R"JSON({
})JSON";
} // namespace

struct RPCWorkQueueTestBase : NoLoggerFixture {
struct WorkQueueTestBase : NoLoggerFixture {
Config cfg = Config{boost::json::parse(JSONConfig)};
WorkQueue queue = WorkQueue::make_WorkQueue(cfg);
};

struct RPCWorkQueueTest : WithPrometheus, RPCWorkQueueTestBase {};
struct WorkQueueTest : WithPrometheus, WorkQueueTestBase {};

TEST_F(RPCWorkQueueTest, WhitelistedExecutionCountAddsUp)
TEST_F(WorkQueueTest, WhitelistedExecutionCountAddsUp)
{
auto constexpr static TOTAL = 512u;
uint32_t executeCount = 0u;
Expand All @@ -77,7 +78,7 @@ TEST_F(RPCWorkQueueTest, WhitelistedExecutionCountAddsUp)
EXPECT_EQ(report.at("max_queue_size"), 2);
}

TEST_F(RPCWorkQueueTest, NonWhitelistedPreventSchedulingAtQueueLimitExceeded)
TEST_F(WorkQueueTest, NonWhitelistedPreventSchedulingAtQueueLimitExceeded)
{
auto constexpr static TOTAL = 3u;
auto expectedCount = 2u;
Expand Down Expand Up @@ -112,35 +113,70 @@ TEST_F(RPCWorkQueueTest, NonWhitelistedPreventSchedulingAtQueueLimitExceeded)
EXPECT_TRUE(unblocked);
}

struct RPCWorkQueueMockPrometheusTest : WithMockPrometheus, RPCWorkQueueTestBase {};
struct WorkQueueStopTest : WorkQueueTest {
testing::StrictMock<testing::MockFunction<void()>> onTasksComplete;
testing::StrictMock<testing::MockFunction<void()>> taskMock;
};

TEST_F(WorkQueueStopTest, RejectsNewTasksWhenStopping)
{
EXPECT_CALL(taskMock, Call());
EXPECT_TRUE(queue.postCoro([this](auto /* yield */) { taskMock.Call(); }, false));

queue.stop([]() {});
EXPECT_FALSE(queue.postCoro([this](auto /* yield */) { taskMock.Call(); }, false));

queue.join();
}

TEST_F(RPCWorkQueueMockPrometheusTest, postCoroCouhters)
TEST_F(WorkQueueStopTest, CallsOnTasksCompleteWhenStoppingAndQueueIsEmpty)
{
EXPECT_CALL(taskMock, Call());
EXPECT_TRUE(queue.postCoro([this](auto /* yield */) { taskMock.Call(); }, false));

EXPECT_CALL(onTasksComplete, Call()).WillOnce([&]() { EXPECT_EQ(queue.size(), 0u); });
queue.stop(onTasksComplete.AsStdFunction());
queue.join();
}
TEST_F(WorkQueueStopTest, CallsOnTasksCompleteWhenStoppingOnLastTask)
{
std::binary_semaphore semaphore{0};

EXPECT_CALL(taskMock, Call());
EXPECT_TRUE(queue.postCoro(
[&](auto /* yield */) {
taskMock.Call();
semaphore.acquire();
},
false
));

EXPECT_CALL(onTasksComplete, Call()).WillOnce([&]() { EXPECT_EQ(queue.size(), 0u); });
queue.stop(onTasksComplete.AsStdFunction());
semaphore.release();

queue.join();
}

struct WorkQueueMockPrometheusTest : WithMockPrometheus, WorkQueueTestBase {};

TEST_F(WorkQueueMockPrometheusTest, postCoroCouhters)
{
auto& queuedMock = makeMock<CounterInt>("work_queue_queued_total_number", "");
auto& durationMock = makeMock<CounterInt>("work_queue_cumulitive_tasks_duration_us", "");
auto& curSizeMock = makeMock<GaugeInt>("work_queue_current_size", "");

std::mutex mtx;
bool canContinue = false;
std::condition_variable cv;
std::binary_semaphore semaphore{0};

EXPECT_CALL(curSizeMock, value()).WillOnce(::testing::Return(0));
EXPECT_CALL(curSizeMock, value()).Times(2).WillRepeatedly(::testing::Return(0));
EXPECT_CALL(curSizeMock, add(1));
EXPECT_CALL(queuedMock, add(1));
EXPECT_CALL(durationMock, add(::testing::Gt(0))).WillOnce([&](auto) {
EXPECT_CALL(curSizeMock, add(-1));
std::unique_lock const lk{mtx};
canContinue = true;
cv.notify_all();
semaphore.release();
});

auto const res = queue.postCoro(
[&](auto /* yield */) {
std::unique_lock lk{mtx};
cv.wait(lk, [&]() { return canContinue; });
},
false
);
auto const res = queue.postCoro([&](auto /* yield */) { semaphore.acquire(); }, false);

ASSERT_TRUE(res);
queue.join();
Expand Down

0 comments on commit 5499b89

Please sign in to comment.