diff --git a/openraft/src/core/raft_core.rs b/openraft/src/core/raft_core.rs index 85dac7392..4381f7e98 100644 --- a/openraft/src/core/raft_core.rs +++ b/openraft/src/core/raft_core.rs @@ -822,6 +822,7 @@ where network, snapshot_network, self.log_store.get_log_reader().await, + self.sm_handle.new_snapshot_reader(), self.tx_notify.clone(), tracing::span!(parent: &self.span, Level::DEBUG, "replication", id=display(self.id), target=display(target)), ) @@ -1674,21 +1675,10 @@ where let _ = node.tx_repl.send(Replicate::logs(RequestId::new_append_entries(id), log_id_range)); } Inflight::Snapshot { id, last_log_id } => { - let _ = last_log_id; - - // Create a channel to let state machine worker to send the snapshot and the replication - // worker to receive it. - let (tx, rx) = C::AsyncRuntime::oneshot(); - - let cmd = sm::Command::get_snapshot(tx); - self.sm_handle - .send(cmd) - .map_err(|e| StorageIOError::read_snapshot(None, AnyError::error(e)))?; - // unwrap: The replication channel must not be dropped or it is a bug. - node.tx_repl.send(Replicate::snapshot(RequestId::new_snapshot(id), rx)).map_err(|_e| { - StorageIOError::read_snapshot(None, AnyError::error("replication channel closed")) - })?; + node.tx_repl.send(Replicate::snapshot(RequestId::new_snapshot(id), last_log_id)).map_err( + |_e| StorageIOError::read_snapshot(None, AnyError::error("replication channel closed")), + )?; } } } else { diff --git a/openraft/src/core/raft_msg/mod.rs b/openraft/src/core/raft_msg/mod.rs index 72ddfbc7e..f72e90409 100644 --- a/openraft/src/core/raft_msg/mod.rs +++ b/openraft/src/core/raft_msg/mod.rs @@ -13,7 +13,6 @@ use crate::raft::SnapshotResponse; use crate::raft::VoteRequest; use crate::raft::VoteResponse; use crate::type_config::alias::LogIdOf; -use crate::type_config::alias::OneshotReceiverOf; use crate::type_config::alias::OneshotSenderOf; use crate::type_config::alias::SnapshotDataOf; use crate::ChangeMembers; @@ -27,8 +26,6 @@ pub(crate) mod external_command; /// A oneshot TX to send result from `RaftCore` to external caller, e.g. `Raft::append_entries`. pub(crate) type ResultSender = OneshotSenderOf>; -pub(crate) type ResultReceiver = OneshotReceiverOf>; - /// TX for Vote Response pub(crate) type VoteTx = ResultSender>; diff --git a/openraft/src/core/sm/handle.rs b/openraft/src/core/sm/handle.rs index 9ba60f859..8a718663c 100644 --- a/openraft/src/core/sm/handle.rs +++ b/openraft/src/core/sm/handle.rs @@ -2,15 +2,18 @@ use tokio::sync::mpsc; -use crate::alias::JoinHandleOf; -use crate::core::sm::Command; +use crate::core::sm; +use crate::type_config::alias::AsyncRuntimeOf; +use crate::type_config::alias::JoinHandleOf; +use crate::AsyncRuntime; use crate::RaftTypeConfig; +use crate::Snapshot; /// State machine worker handle for sending command to it. pub(crate) struct Handle where C: RaftTypeConfig { - pub(in crate::core::sm) cmd_tx: mpsc::UnboundedSender>, + pub(in crate::core::sm) cmd_tx: mpsc::UnboundedSender>, #[allow(dead_code)] pub(in crate::core::sm) join_handle: JoinHandleOf, @@ -19,8 +22,63 @@ where C: RaftTypeConfig impl Handle where C: RaftTypeConfig { - pub(crate) fn send(&mut self, cmd: Command) -> Result<(), mpsc::error::SendError>> { + pub(crate) fn send(&mut self, cmd: sm::Command) -> Result<(), mpsc::error::SendError>> { tracing::debug!("sending command to state machine worker: {:?}", cmd); self.cmd_tx.send(cmd) } + + /// Create a [`SnapshotReader`] to get the current snapshot from the state machine. + pub(crate) fn new_snapshot_reader(&self) -> SnapshotReader { + SnapshotReader { + cmd_tx: self.cmd_tx.downgrade(), + } + } +} + +/// A handle for retrieving a snapshot from the state machine. +pub(crate) struct SnapshotReader +where C: RaftTypeConfig +{ + /// Weak command sender to the state machine worker. + /// + /// It is weak because the [`Worker`] watches the close event of this channel for shutdown. + /// + /// [`Worker`]: sm::worker::Worker + cmd_tx: mpsc::WeakUnboundedSender>, +} + +impl SnapshotReader +where C: RaftTypeConfig +{ + /// Get a snapshot from the state machine. + /// + /// If the state machine worker has shutdown, it will return an error. + /// If there is not snapshot available, it will return `Ok(None)`. + pub(crate) async fn get_snapshot(&self) -> Result>, &'static str> { + let (tx, rx) = AsyncRuntimeOf::::oneshot(); + + let cmd = sm::Command::get_snapshot(tx); + tracing::debug!("SnapshotReader sending command to sm::Worker: {:?}", cmd); + + let Some(cmd_tx) = self.cmd_tx.upgrade() else { + tracing::info!("failed to upgrade cmd_tx, sm::Worker may have shutdown"); + return Err("failed to upgrade cmd_tx, sm::Worker may have shutdown"); + }; + + // If fail to send command, cmd is dropped and tx will be dropped. + let _ = cmd_tx.send(cmd); + + let got = match rx.await { + Ok(x) => x, + Err(_e) => { + tracing::error!("failed to receive snapshot, sm::Worker may have shutdown"); + return Err("failed to receive snapshot, sm::Worker may have shutdown"); + } + }; + + // Safe unwrap(): error is Infallible. + let snapshot = got.unwrap(); + + Ok(snapshot) + } } diff --git a/openraft/src/core/sm/worker.rs b/openraft/src/core/sm/worker.rs index 079d3a1c5..1e69d886e 100644 --- a/openraft/src/core/sm/worker.rs +++ b/openraft/src/core/sm/worker.rs @@ -1,6 +1,5 @@ use tokio::sync::mpsc; -use crate::alias::JoinHandleOf; use crate::async_runtime::AsyncOneshotSendExt; use crate::core::notify::Notify; use crate::core::raft_msg::ResultSender; @@ -15,6 +14,7 @@ use crate::core::ApplyingEntry; use crate::display_ext::DisplayOptionExt; use crate::entry::RaftPayload; use crate::storage::RaftStateMachine; +use crate::type_config::alias::JoinHandleOf; use crate::AsyncRuntime; use crate::RaftLogId; use crate::RaftSnapshotBuilder; diff --git a/openraft/src/replication/mod.rs b/openraft/src/replication/mod.rs index 80d31975c..d437607ce 100644 --- a/openraft/src/replication/mod.rs +++ b/openraft/src/replication/mod.rs @@ -26,7 +26,7 @@ use tracing_futures::Instrument; use crate::config::Config; use crate::core::notify::Notify; -use crate::core::raft_msg::ResultReceiver; +use crate::core::sm::handle::SnapshotReader; use crate::display_ext::DisplayOptionExt; use crate::error::HigherVote; use crate::error::PayloadTooLarge; @@ -53,6 +53,7 @@ use crate::storage::Snapshot; use crate::type_config::alias::AsyncRuntimeOf; use crate::type_config::alias::InstantOf; use crate::type_config::alias::JoinHandleOf; +use crate::type_config::alias::LogIdOf; use crate::AsyncRuntime; use crate::Instant; use crate::LogId; @@ -127,6 +128,9 @@ where /// The [`RaftLogStorage::LogReader`] interface. log_reader: LS::LogReader, + /// The handle to get a snapshot directly from state machine. + snapshot_reader: SnapshotReader, + /// The Raft's runtime config. config: Arc, @@ -163,6 +167,7 @@ where network: N::Network, snapshot_network: N::Network, log_reader: LS::LogReader, + snapshot_reader: SnapshotReader, tx_raft_core: mpsc::UnboundedSender>, span: tracing::Span, ) -> ReplicationHandle { @@ -185,6 +190,7 @@ where snapshot_state: None, backoff: None, log_reader, + snapshot_reader, config, committed, matching, @@ -697,21 +703,17 @@ where #[tracing::instrument(level = "info", skip_all)] async fn stream_snapshot( &mut self, - snapshot_rx: DataWithId>>>, + snapshot_req: DataWithId>>, ) -> Result>, ReplicationError> { - let request_id = snapshot_rx.request_id(); - let rx = snapshot_rx.into_data(); + let request_id = snapshot_req.request_id(); tracing::info!(request_id = display(request_id), "{}", func_name!()); - let snapshot = rx.await.map_err(|e| { - let io_err = StorageIOError::read_snapshot(None, AnyError::error(e)); - StorageError::IO { source: io_err } + let snapshot = self.snapshot_reader.get_snapshot().await.map_err(|reason| { + tracing::warn!(error = display(&reason), "failed to get snapshot from state machine"); + ReplicationClosed::new(reason) })?; - // Safe unwrap(): the error is Infallible, so it is safe to unwrap. - let snapshot = snapshot.unwrap(); - tracing::info!( "received snapshot: request_id={}; meta:{}", request_id, diff --git a/openraft/src/replication/request.rs b/openraft/src/replication/request.rs index df1b6aa13..92bc2d573 100644 --- a/openraft/src/replication/request.rs +++ b/openraft/src/replication/request.rs @@ -1,5 +1,7 @@ use std::fmt; +use crate::type_config::alias::LogIdOf; + /// A replication request sent by RaftCore leader state to replication stream. #[derive(Debug)] pub(crate) enum Replicate @@ -22,8 +24,8 @@ where C: RaftTypeConfig Self::Data(Data::new_logs(id, log_id_range)) } - pub(crate) fn snapshot(id: RequestId, snapshot_rx: ResultReceiver>>) -> Self { - Self::Data(Data::new_snapshot(id, snapshot_rx)) + pub(crate) fn snapshot(id: RequestId, last_log_id: Option>) -> Self { + Self::Data(Data::new_snapshot(id, last_log_id)) } pub(crate) fn new_data(data: Data) -> Self { @@ -49,7 +51,6 @@ where C: RaftTypeConfig } } -use crate::core::raft_msg::ResultReceiver; use crate::display_ext::DisplayOptionExt; use crate::error::Fatal; use crate::error::StreamingError; @@ -61,7 +62,6 @@ use crate::type_config::alias::InstantOf; use crate::LogId; use crate::MessageSummary; use crate::RaftTypeConfig; -use crate::Snapshot; use crate::SnapshotMeta; /// Request to replicate a chunk of data, logs or snapshot. @@ -74,7 +74,7 @@ where C: RaftTypeConfig { Heartbeat, Logs(DataWithId>), - Snapshot(DataWithId>>>), + Snapshot(DataWithId>>), SnapshotCallback(DataWithId>), } @@ -143,8 +143,8 @@ where C: RaftTypeConfig Self::Logs(DataWithId::new(request_id, log_id_range)) } - pub(crate) fn new_snapshot(request_id: RequestId, snapshot_rx: ResultReceiver>>) -> Self { - Self::Snapshot(DataWithId::new(request_id, snapshot_rx)) + pub(crate) fn new_snapshot(request_id: RequestId, last_log_id: Option>) -> Self { + Self::Snapshot(DataWithId::new(request_id, last_log_id)) } pub(crate) fn new_snapshot_callback(