Skip to content

Commit

Permalink
Feature: Have oneshot as a Runtime implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Anthony Griffon <[email protected]>
  • Loading branch information
Miaxos committed Feb 23, 2024
1 parent c641db3 commit 4d1ed28
Show file tree
Hide file tree
Showing 28 changed files with 339 additions and 129 deletions.
11 changes: 8 additions & 3 deletions cluster_benchmark/tests/benchmark/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use openraft::Entry;
use openraft::EntryPayload;
use openraft::LogId;
use openraft::OptionalSend;
use openraft::OptionalSync;
use openraft::RaftLogId;
use openraft::RaftTypeConfig;
use openraft::SnapshotMeta;
Expand Down Expand Up @@ -225,8 +224,14 @@ impl RaftLogStorage<TypeConfig> for Arc<LogStore> {
}

#[tracing::instrument(level = "trace", skip_all)]
async fn append<I>(&mut self, entries: I, callback: LogFlushed<NodeId>) -> Result<(), StorageError<NodeId>>
where I: IntoIterator<Item = Entry<TypeConfig>> + Send {
async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<<TypeConfig as RaftTypeConfig>::AsyncRuntime, NodeId>,
) -> Result<(), StorageError<NodeId>>
where
I: IntoIterator<Item = Entry<TypeConfig>> + Send,
{
{
let mut log = self.log.write().await;
log.extend(entries.into_iter().map(|entry| (entry.get_log_id().index, entry)));
Expand Down
12 changes: 9 additions & 3 deletions examples/memstore/src/log_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,14 @@ impl<C: RaftTypeConfig> LogStoreInner<C> {
Ok(self.vote)
}

async fn append<I>(&mut self, entries: I, callback: LogFlushed<C::NodeId>) -> Result<(), StorageError<C::NodeId>>
where I: IntoIterator<Item = C::Entry> {
async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<C::AsyncRuntime, C::NodeId>,
) -> Result<(), StorageError<C::NodeId>>
where
I: IntoIterator<Item = C::Entry>,
{
// Simple implementation that calls the flush-before-return `append_to_log`.
for entry in entries {
self.log.insert(entry.get_log_id().index, entry);
Expand Down Expand Up @@ -191,7 +197,7 @@ mod impl_log_store {
async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<C::NodeId>,
callback: LogFlushed<C::AsyncRuntime, C::NodeId>,
) -> Result<(), StorageError<C::NodeId>>
where
I: IntoIterator<Item = C::Entry>,
Expand Down
10 changes: 8 additions & 2 deletions examples/raft-kv-memstore-opendal-snapshot-data/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,14 @@ impl RaftLogStorage<TypeConfig> for Arc<LogStore> {
}

#[tracing::instrument(level = "trace", skip(self, entries, callback))]
async fn append<I>(&mut self, entries: I, callback: LogFlushed<NodeId>) -> Result<(), StorageError<NodeId>>
where I: IntoIterator<Item = Entry<TypeConfig>> {
async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<<TypeConfig as RaftTypeConfig>::AsyncRuntime, NodeId>,
) -> Result<(), StorageError<NodeId>>
where
I: IntoIterator<Item = Entry<TypeConfig>>,
{
// Simple implementation that calls the flush-before-return `append_to_log`.
let mut log = self.log.lock().unwrap();
for entry in entries {
Expand Down
10 changes: 8 additions & 2 deletions examples/raft-kv-memstore-singlethreaded/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,14 @@ impl RaftLogStorage<TypeConfig> for Rc<LogStore> {
}

#[tracing::instrument(level = "trace", skip(self, entries, callback))]
async fn append<I>(&mut self, entries: I, callback: LogFlushed<NodeId>) -> Result<(), StorageError<NodeId>>
where I: IntoIterator<Item = Entry<TypeConfig>> {
async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<<TypeConfig as RaftTypeConfig>::AsyncRuntime, NodeId>,
) -> Result<(), StorageError<NodeId>>
where
I: IntoIterator<Item = Entry<TypeConfig>>,
{
// Simple implementation that calls the flush-before-return `append_to_log`.
let mut log = self.log.borrow_mut();
for entry in entries {
Expand Down
7 changes: 6 additions & 1 deletion examples/raft-kv-rocksdb/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use openraft::LogId;
use openraft::OptionalSend;
use openraft::RaftLogReader;
use openraft::RaftSnapshotBuilder;
use openraft::RaftTypeConfig;
use openraft::SnapshotMeta;
use openraft::StorageError;
use openraft::StorageIOError;
Expand Down Expand Up @@ -436,7 +437,11 @@ impl RaftLogStorage<TypeConfig> for LogStore {
}

#[tracing::instrument(level = "trace", skip_all)]
async fn append<I>(&mut self, entries: I, callback: LogFlushed<NodeId>) -> StorageResult<()>
async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<<TypeConfig as RaftTypeConfig>::AsyncRuntime, NodeId>,
) -> StorageResult<()>
where
I: IntoIterator<Item = Entry<TypeConfig>> + Send,
I::IntoIter: Send,
Expand Down
58 changes: 58 additions & 0 deletions openraft/src/async_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static
/// Type of a thread-local random number generator.
type ThreadLocalRng: rand::Rng;

/// Type of a `oneshot` sender.
///
/// It must implement `AsyncOneshotSendExt<T>`, which provides the `send` method
/// used to deliver the single value.
type OneshotSender<T: OptionalSend>: AsyncOneshotSendExt<T> + OptionalSend + OptionalSync + Debug + Sized;

/// Type of the error yielded by a `oneshot` receiver, i.e. the `Err` variant of
/// the output of `OneshotReceiver`.
type OneshotReceiverError: std::error::Error + OptionalSend;

/// Type of a `oneshot` receiver.
///
/// It is a future whose output is `Result<T, Self::OneshotReceiverError>`, so the
/// consumer obtains the value by awaiting the receiver.
type OneshotReceiver<T: OptionalSend>: OptionalSend
+ OptionalSync
+ Future<Output = Result<T, Self::OneshotReceiverError>>
+ Unpin;

/// Spawn a new task.
fn spawn<T>(future: T) -> Self::JoinHandle<T::Output>
where
Expand Down Expand Up @@ -72,12 +83,24 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static
/// This is a per-thread instance, which cannot be shared across threads or
/// sent to another thread.
fn thread_rng() -> Self::ThreadLocalRng;

/// Creates a new one-shot channel for sending single values.
///
/// The function returns separate "send" and "receive" handles. The `Sender`
/// handle is used by the producer to send the value. The `Receiver` handle is
/// used by the consumer to receive the value.
///
/// Each handle can be used on separate tasks.
///
/// The handles are returned as a `(sender, receiver)` tuple, in that order.
fn oneshot<T>() -> (Self::OneshotSender<T>, Self::OneshotReceiver<T>)
where T: OptionalSend;
}

/// `Tokio` is the default asynchronous executor.
///
/// This is a unit struct that carries no state; the runtime behavior is provided
/// by its `AsyncRuntime` implementation.
#[derive(Debug, Default)]
pub struct TokioRuntime;

/// Newtype around `tokio::sync::oneshot::Sender`.
///
/// `TokioRuntime` uses this type as its `AsyncRuntime::OneshotSender`.
pub struct TokioSendWrapper<T: OptionalSend>(pub tokio::sync::oneshot::Sender<T>);

impl AsyncRuntime for TokioRuntime {
type JoinError = tokio::task::JoinError;
type JoinHandle<T: OptionalSend + 'static> = tokio::task::JoinHandle<T>;
Expand All @@ -86,6 +109,9 @@ impl AsyncRuntime for TokioRuntime {
type TimeoutError = tokio::time::error::Elapsed;
type Timeout<R, T: Future<Output = R> + OptionalSend> = tokio::time::Timeout<T>;
type ThreadLocalRng = rand::rngs::ThreadRng;
type OneshotSender<T: OptionalSend> = TokioSendWrapper<T>;
type OneshotReceiver<T: OptionalSend> = tokio::sync::oneshot::Receiver<T>;
type OneshotReceiverError = tokio::sync::oneshot::error::RecvError;

#[inline]
fn spawn<T>(future: T) -> Self::JoinHandle<T::Output>
Expand Down Expand Up @@ -132,4 +158,36 @@ impl AsyncRuntime for TokioRuntime {
fn thread_rng() -> Self::ThreadLocalRng {
rand::thread_rng()
}

/// Builds a Tokio one-shot channel and returns its two halves, with the sending
/// half wrapped in `TokioSendWrapper`.
#[inline]
fn oneshot<T>() -> (Self::OneshotSender<T>, Self::OneshotReceiver<T>)
where
    T: OptionalSend,
{
    let (sender, receiver) = tokio::sync::oneshot::channel();
    (TokioSendWrapper(sender), receiver)
}
}

/// Sending half of a `oneshot` channel, abstracted over the async runtime.
///
/// `AsyncRuntime::OneshotSender` is required to implement this trait.
pub trait AsyncOneshotSendExt<T>: Unpin {
/// Attempts to send a value on this channel, returning it back if it could
/// not be sent.
///
/// This method consumes `self` as only one value may ever be sent on a `oneshot`
/// channel. It is not marked async because sending a message to an `oneshot`
/// channel never requires any form of waiting. Because of this, the `send`
/// method can be used in both synchronous and asynchronous code without
/// problems.
fn send(self, t: T) -> Result<(), T>;
}

impl<T: OptionalSend> AsyncOneshotSendExt<T> for TokioSendWrapper<T> {
    /// Forwards the value to the wrapped Tokio sender and returns its result unchanged.
    #[inline]
    fn send(self, t: T) -> Result<(), T> {
        let Self(inner) = self;
        inner.send(t)
    }
}

impl<T: OptionalSend> Debug for TokioSendWrapper<T> {
    /// Writes only the type name; the wrapped sender is not included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("TokioSendWrapper")
    }
}
41 changes: 29 additions & 12 deletions openraft/src/core/raft_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ use futures::TryFutureExt;
use maplit::btreeset;
use tokio::select;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::watch;
use tracing::Instrument;
use tracing::Level;
use tracing::Span;

use crate::async_runtime::AsyncOneshotSendExt;
use crate::config::Config;
use crate::config::RuntimeConfig;
use crate::core::balancer::Balancer;
Expand Down Expand Up @@ -215,7 +215,10 @@ where
SM: RaftStateMachine<C>,
{
/// The main loop of the Raft protocol.
pub(crate) async fn main(mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
pub(crate) async fn main(
mut self,
rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
let span = tracing::span!(parent: &self.span, Level::DEBUG, "main");
let res = self.do_main(rx_shutdown).instrument(span).await;

Expand All @@ -239,7 +242,10 @@ where
}

#[tracing::instrument(level="trace", skip_all, fields(id=display(self.id), cluster=%self.config.cluster_name))]
async fn do_main(&mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
async fn do_main(
&mut self,
rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
tracing::debug!("raft node is initializing");

self.engine.startup();
Expand Down Expand Up @@ -432,7 +438,7 @@ where
&mut self,
changes: ChangeMembers<C::NodeId, C::Node>,
retain: bool,
tx: ResultSender<ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
tx: ResultSender<AsyncRuntimeOf<C>, ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
) {
let res = self.engine.state.membership_state.change_handler().apply(changes, retain);
let new_membership = match res {
Expand Down Expand Up @@ -593,7 +599,7 @@ where
pub(crate) fn handle_initialize(
&mut self,
member_nodes: BTreeMap<C::NodeId, C::Node>,
tx: ResultSender<(), InitializeError<C::NodeId, C::Node>>,
tx: ResultSender<AsyncRuntimeOf<C>, (), InitializeError<C::NodeId, C::Node>>,
) {
tracing::debug!(member_nodes = debug(&member_nodes), "{}", func_name!());

Expand All @@ -616,8 +622,12 @@ where

/// Reject a request due to the Raft node being in a state which prohibits the request.
#[tracing::instrument(level = "trace", skip(self, tx))]
pub(crate) fn reject_with_forward_to_leader<T, E>(&self, tx: ResultSender<T, E>)
where E: From<ForwardToLeader<C::NodeId, C::Node>> {
pub(crate) fn reject_with_forward_to_leader<T: OptionalSend, E: OptionalSend>(
&self,
tx: ResultSender<AsyncRuntimeOf<C>, T, E>,
) where
E: From<ForwardToLeader<C::NodeId, C::Node>>,
{
let mut leader_id = self.current_leader();
let leader_node = self.get_leader_node(leader_id);

Expand Down Expand Up @@ -680,7 +690,7 @@ where
{
tracing::debug!("append_to_log");

let (tx, rx) = oneshot::channel();
let (tx, rx) = C::AsyncRuntime::oneshot();
let callback = LogFlushed::new(Some(last_log_id), tx);
self.log_store.append(entries, callback).await?;
rx.await
Expand Down Expand Up @@ -865,7 +875,10 @@ where

/// Run an event handling loop
#[tracing::instrument(level="debug", skip_all, fields(id=display(self.id)))]
async fn runtime_loop(&mut self, mut rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
async fn runtime_loop(
&mut self,
mut rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
// The balancer controls the ratio between the number of RaftMsg to process and the number of Notify to process.
let mut balancer = Balancer::new(10_000);

Expand Down Expand Up @@ -1067,7 +1080,11 @@ where
}

#[tracing::instrument(level = "debug", skip_all)]
pub(super) fn handle_vote_request(&mut self, req: VoteRequest<C::NodeId>, tx: VoteTx<C::NodeId>) {
pub(super) fn handle_vote_request(
&mut self,
req: VoteRequest<C::NodeId>,
tx: VoteTx<AsyncRuntimeOf<C>, C::NodeId>,
) {
tracing::info!(req = display(req.summary()), func = func_name!());

let resp = self.engine.handle_vote_req(req);
Expand All @@ -1081,7 +1098,7 @@ where
pub(super) fn handle_append_entries_request(
&mut self,
req: AppendEntriesRequest<C>,
tx: AppendEntriesTx<C::NodeId>,
tx: AppendEntriesTx<AsyncRuntimeOf<C>, C::NodeId>,
) {
tracing::debug!(req = display(req.summary()), func = func_name!());

Expand Down Expand Up @@ -1657,7 +1674,7 @@ where

// Create a channel that lets the state machine worker send the snapshot and the replication
// worker receive it.
let (tx, rx) = oneshot::channel();
let (tx, rx) = C::AsyncRuntime::oneshot();

let cmd = sm::Command::get_snapshot(tx);
self.sm_handle
Expand Down
5 changes: 4 additions & 1 deletion openraft/src/core/raft_msg/external_command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::fmt;

use crate::core::raft_msg::ResultSender;
use crate::type_config::alias::AsyncRuntimeOf;
use crate::RaftTypeConfig;
use crate::Snapshot;

Expand All @@ -23,7 +24,9 @@ pub(crate) enum ExternalCommand<C: RaftTypeConfig> {
Snapshot,

/// Get a snapshot from the state machine and send it back via a oneshot sender.
GetSnapshot { tx: ResultSender<Option<Snapshot<C>>> },
GetSnapshot {
tx: ResultSender<AsyncRuntimeOf<C>, Option<Snapshot<C>>>,
},

/// Purge logs covered by a snapshot up to a specified index.
///
Expand Down
Loading

0 comments on commit 4d1ed28

Please sign in to comment.