Skip to content

Commit

Permalink
Refactor: ReplicationCore get a snapshot directly from state machin…
Browse files Browse the repository at this point in the history
…e, via `SnapshotReader`
  • Loading branch information
drmingdrmer committed Mar 25, 2024
1 parent bb75f7b commit 49f7eae
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 37 deletions.
18 changes: 4 additions & 14 deletions openraft/src/core/raft_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,7 @@ where
network,
snapshot_network,
self.log_store.get_log_reader().await,
self.sm_handle.new_snapshot_reader(),
self.tx_notify.clone(),
tracing::span!(parent: &self.span, Level::DEBUG, "replication", id=display(self.id), target=display(target)),
)
Expand Down Expand Up @@ -1674,21 +1675,10 @@ where
let _ = node.tx_repl.send(Replicate::logs(RequestId::new_append_entries(id), log_id_range));
}
Inflight::Snapshot { id, last_log_id } => {
let _ = last_log_id;

// Create a channel to let state machine worker to send the snapshot and the replication
// worker to receive it.
let (tx, rx) = C::AsyncRuntime::oneshot();

let cmd = sm::Command::get_snapshot(tx);
self.sm_handle
.send(cmd)
.map_err(|e| StorageIOError::read_snapshot(None, AnyError::error(e)))?;

// unwrap: The replication channel must not be dropped or it is a bug.
node.tx_repl.send(Replicate::snapshot(RequestId::new_snapshot(id), rx)).map_err(|_e| {
StorageIOError::read_snapshot(None, AnyError::error("replication channel closed"))
})?;
node.tx_repl.send(Replicate::snapshot(RequestId::new_snapshot(id), last_log_id)).map_err(
|_e| StorageIOError::read_snapshot(None, AnyError::error("replication channel closed")),
)?;
}
}
} else {
Expand Down
3 changes: 0 additions & 3 deletions openraft/src/core/raft_msg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use crate::raft::SnapshotResponse;
use crate::raft::VoteRequest;
use crate::raft::VoteResponse;
use crate::type_config::alias::LogIdOf;
use crate::type_config::alias::OneshotReceiverOf;
use crate::type_config::alias::OneshotSenderOf;
use crate::type_config::alias::SnapshotDataOf;
use crate::ChangeMembers;
Expand All @@ -27,8 +26,6 @@ pub(crate) mod external_command;
/// A oneshot TX to send result from `RaftCore` to external caller, e.g. `Raft::append_entries`.
pub(crate) type ResultSender<C, T, E = Infallible> = OneshotSenderOf<C, Result<T, E>>;

pub(crate) type ResultReceiver<C, T, E = Infallible> = OneshotReceiverOf<C, Result<T, E>>;

/// TX for Vote Response
pub(crate) type VoteTx<C> = ResultSender<C, VoteResponse<C>>;

Expand Down
66 changes: 62 additions & 4 deletions openraft/src/core/sm/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
use tokio::sync::mpsc;

use crate::alias::JoinHandleOf;
use crate::core::sm::Command;
use crate::core::sm;
use crate::type_config::alias::AsyncRuntimeOf;
use crate::type_config::alias::JoinHandleOf;
use crate::AsyncRuntime;
use crate::RaftTypeConfig;
use crate::Snapshot;

/// State machine worker handle for sending command to it.
pub(crate) struct Handle<C>
where C: RaftTypeConfig
{
pub(in crate::core::sm) cmd_tx: mpsc::UnboundedSender<Command<C>>,
pub(in crate::core::sm) cmd_tx: mpsc::UnboundedSender<sm::Command<C>>,

#[allow(dead_code)]
pub(in crate::core::sm) join_handle: JoinHandleOf<C, ()>,
Expand All @@ -19,8 +22,63 @@ where C: RaftTypeConfig
impl<C> Handle<C>
where C: RaftTypeConfig
{
pub(crate) fn send(&mut self, cmd: Command<C>) -> Result<(), mpsc::error::SendError<Command<C>>> {
pub(crate) fn send(&mut self, cmd: sm::Command<C>) -> Result<(), mpsc::error::SendError<sm::Command<C>>> {
tracing::debug!("sending command to state machine worker: {:?}", cmd);
self.cmd_tx.send(cmd)
}

/// Create a [`SnapshotReader`] to get the current snapshot from the state machine.
pub(crate) fn new_snapshot_reader(&self) -> SnapshotReader<C> {
SnapshotReader {
cmd_tx: self.cmd_tx.downgrade(),
}
}
}

/// A handle for retrieving a snapshot from the state machine.
pub(crate) struct SnapshotReader<C>
where C: RaftTypeConfig
{
/// Weak command sender to the state machine worker.
///
/// It is weak because the [`Worker`] watches the close event of this channel for shutdown.
///
/// [`Worker`]: sm::worker::Worker
cmd_tx: mpsc::WeakUnboundedSender<sm::Command<C>>,
}

impl<C> SnapshotReader<C>
where C: RaftTypeConfig
{
/// Get a snapshot from the state machine.
///
/// If the state machine worker has shutdown, it will return an error.
/// If there is not snapshot available, it will return `Ok(None)`.
pub(crate) async fn get_snapshot(&self) -> Result<Option<Snapshot<C>>, &'static str> {
let (tx, rx) = AsyncRuntimeOf::<C>::oneshot();

let cmd = sm::Command::get_snapshot(tx);
tracing::debug!("SnapshotReader sending command to sm::Worker: {:?}", cmd);

let Some(cmd_tx) = self.cmd_tx.upgrade() else {
tracing::info!("failed to upgrade cmd_tx, sm::Worker may have shutdown");
return Err("failed to upgrade cmd_tx, sm::Worker may have shutdown");
};

// If fail to send command, cmd is dropped and tx will be dropped.
let _ = cmd_tx.send(cmd);

let got = match rx.await {
Ok(x) => x,
Err(_e) => {
tracing::error!("failed to receive snapshot, sm::Worker may have shutdown");
return Err("failed to receive snapshot, sm::Worker may have shutdown");
}
};

// Safe unwrap(): error is Infallible.
let snapshot = got.unwrap();

Ok(snapshot)
}
}
20 changes: 11 additions & 9 deletions openraft/src/replication/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use tracing_futures::Instrument;

use crate::config::Config;
use crate::core::notify::Notify;
use crate::core::raft_msg::ResultReceiver;
use crate::core::sm::handle::SnapshotReader;
use crate::display_ext::DisplayOptionExt;
use crate::error::HigherVote;
use crate::error::PayloadTooLarge;
Expand All @@ -53,6 +53,7 @@ use crate::storage::Snapshot;
use crate::type_config::alias::AsyncRuntimeOf;
use crate::type_config::alias::InstantOf;
use crate::type_config::alias::JoinHandleOf;
use crate::type_config::alias::LogIdOf;
use crate::AsyncRuntime;
use crate::Instant;
use crate::LogId;
Expand Down Expand Up @@ -127,6 +128,9 @@ where
/// The [`RaftLogStorage::LogReader`] interface.
log_reader: LS::LogReader,

/// The handle to get a snapshot directly from state machine.
snapshot_reader: SnapshotReader<C>,

/// The Raft's runtime config.
config: Arc<Config>,

Expand Down Expand Up @@ -163,6 +167,7 @@ where
network: N::Network,
snapshot_network: N::Network,
log_reader: LS::LogReader,
snapshot_reader: SnapshotReader<C>,
tx_raft_core: mpsc::UnboundedSender<Notify<C>>,
span: tracing::Span,
) -> ReplicationHandle<C> {
Expand All @@ -185,6 +190,7 @@ where
snapshot_state: None,
backoff: None,
log_reader,
snapshot_reader,
config,
committed,
matching,
Expand Down Expand Up @@ -697,21 +703,17 @@ where
#[tracing::instrument(level = "info", skip_all)]
async fn stream_snapshot(
&mut self,
snapshot_rx: DataWithId<ResultReceiver<C, Option<Snapshot<C>>>>,
snapshot_rx: DataWithId<Option<LogIdOf<C>>>,
) -> Result<Option<Data<C>>, ReplicationError<C>> {
let request_id = snapshot_rx.request_id();
let rx = snapshot_rx.into_data();

tracing::info!(request_id = display(request_id), "{}", func_name!());

let snapshot = rx.await.map_err(|e| {
let io_err = StorageIOError::read_snapshot(None, AnyError::error(e));
StorageError::IO { source: io_err }
let snapshot = self.snapshot_reader.get_snapshot().await.map_err(|reason| {
tracing::warn!(error = display(&reason), "failed to get snapshot from state machine");
ReplicationClosed::new(reason)
})?;

// Safe unwrap(): the error is Infallible, so it is safe to unwrap.
let snapshot = snapshot.unwrap();

tracing::info!(
"received snapshot: request_id={}; meta:{}",
request_id,
Expand Down
14 changes: 7 additions & 7 deletions openraft/src/replication/request.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt;

use crate::type_config::alias::LogIdOf;

/// A replication request sent by RaftCore leader state to replication stream.
#[derive(Debug)]
pub(crate) enum Replicate<C>
Expand All @@ -22,8 +24,8 @@ where C: RaftTypeConfig
Self::Data(Data::new_logs(id, log_id_range))
}

pub(crate) fn snapshot(id: RequestId, snapshot_rx: ResultReceiver<C, Option<Snapshot<C>>>) -> Self {
Self::Data(Data::new_snapshot(id, snapshot_rx))
pub(crate) fn snapshot(id: RequestId, last_log_id: Option<LogIdOf<C>>) -> Self {
Self::Data(Data::new_snapshot(id, last_log_id))
}

pub(crate) fn new_data(data: Data<C>) -> Self {
Expand All @@ -49,7 +51,6 @@ where C: RaftTypeConfig
}
}

use crate::core::raft_msg::ResultReceiver;
use crate::display_ext::DisplayOptionExt;
use crate::error::Fatal;
use crate::error::StreamingError;
Expand All @@ -61,7 +62,6 @@ use crate::type_config::alias::InstantOf;
use crate::LogId;
use crate::MessageSummary;
use crate::RaftTypeConfig;
use crate::Snapshot;
use crate::SnapshotMeta;

/// Request to replicate a chunk of data, logs or snapshot.
Expand All @@ -74,7 +74,7 @@ where C: RaftTypeConfig
{
Heartbeat,
Logs(DataWithId<LogIdRange<C::NodeId>>),
Snapshot(DataWithId<ResultReceiver<C, Option<Snapshot<C>>>>),
Snapshot(DataWithId<Option<LogIdOf<C>>>),
SnapshotCallback(DataWithId<SnapshotCallback<C>>),
}

Expand Down Expand Up @@ -143,8 +143,8 @@ where C: RaftTypeConfig
Self::Logs(DataWithId::new(request_id, log_id_range))
}

pub(crate) fn new_snapshot(request_id: RequestId, snapshot_rx: ResultReceiver<C, Option<Snapshot<C>>>) -> Self {
Self::Snapshot(DataWithId::new(request_id, snapshot_rx))
pub(crate) fn new_snapshot(request_id: RequestId, last_log_id: Option<LogIdOf<C>>) -> Self {
Self::Snapshot(DataWithId::new(request_id, last_log_id))
}

pub(crate) fn new_snapshot_callback(
Expand Down

0 comments on commit 49f7eae

Please sign in to comment.