From 70273deb41a2ad419fcd877287d51196b19b6a0c Mon Sep 17 00:00:00 2001 From: Sainath Singineedi Date: Thu, 12 Dec 2024 15:44:53 +0530 Subject: [PATCH] Fix: throw unreachable error if connection fails --- .../raft-kv-memstore-grpc/src/network/mod.rs | 46 ++++++++++++++----- 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/examples/raft-kv-memstore-grpc/src/network/mod.rs b/examples/raft-kv-memstore-grpc/src/network/mod.rs index f26bedca0..e92ca9afb 100644 --- a/examples/raft-kv-memstore-grpc/src/network/mod.rs +++ b/examples/raft-kv-memstore-grpc/src/network/mod.rs @@ -1,6 +1,7 @@ use bincode::deserialize; use bincode::serialize; use openraft::error::NetworkError; +use openraft::error::Unreachable; use openraft::network::v2::RaftNetworkV2; use openraft::network::RPCOption; use openraft::raft::AppendEntriesRequest; @@ -24,8 +25,6 @@ use crate::TypeConfig; /// Provides the networking layer for Raft nodes to communicate with each other. pub struct Network {} -type RaftServiceClient = InternalServiceClient; - impl Network {} /// Implementation of the RaftNetworkFactory trait for creating new network connections. @@ -35,21 +34,20 @@ impl RaftNetworkFactory for Network { #[tracing::instrument(level = "debug", skip_all)] async fn new_client(&mut self, _: NodeId, node: &Node) -> Self::Network { - let channel = Channel::builder(format!("http://{}", node.rpc_addr).parse().unwrap()).connect().await.unwrap(); - NetworkConnection::new(InternalServiceClient::new(channel)) + NetworkConnection::new(node.clone()) } } /// Represents an active network connection to a remote Raft node. /// Handles serialization and deserialization of Raft messages over gRPC. pub struct NetworkConnection { - client: RaftServiceClient, + target_node: Node, } impl NetworkConnection { /// Creates a new NetworkConnection with the provided gRPC client. - pub fn new(client: RaftServiceClient) -> Self { - NetworkConnection { client } + pub fn new(target_node: Node) -> Self { + NetworkConnection { target_node } } } @@ -61,10 +59,18 @@ impl RaftNetworkV2 for NetworkConnection { req: AppendEntriesRequest, _option: RPCOption, ) -> Result, RPCError> { + let server_addr = self.target_node.rpc_addr.clone(); + let channel = match Channel::builder(format!("http://{}", server_addr).parse().unwrap()).connect().await { + Ok(channel) => channel, + Err(e) => { + return Err(openraft::error::RPCError::Unreachable(Unreachable::new(&e))); + } + }; + let mut client = InternalServiceClient::new(channel); + let value = serialize(&req).map_err(|e| RPCError::Network(NetworkError::new(&e)))?; let request = RaftRequestBytes { value }; - let response = - self.client.append_entries(request).await.map_err(|e| RPCError::Network(NetworkError::new(&e)))?; + let response = client.append_entries(request).await.map_err(|e| RPCError::Network(NetworkError::new(&e)))?; let message = response.into_inner(); let result = deserialize(&message.value).map_err(|e| RPCError::Network(NetworkError::new(&e)))?; Ok(result) @@ -77,6 +83,14 @@ impl RaftNetworkV2 for NetworkConnection { _cancel: impl std::future::Future + openraft::OptionalSend + 'static, _option: RPCOption, ) -> Result, crate::typ::StreamingError> { + let server_addr = self.target_node.rpc_addr.clone(); + let channel = match Channel::builder(format!("http://{}", server_addr).parse().unwrap()).connect().await { + Ok(channel) => channel, + Err(e) => { + return Err(openraft::error::RPCError::Unreachable(Unreachable::new(&e)).into()); + } + }; + let mut client = InternalServiceClient::new(channel); // Serialize the vote and snapshot metadata let rpc_meta = serialize(&(vote, snapshot.meta.clone())).map_err(|e| RPCError::Network(NetworkError::new(&e)))?; @@ -106,8 +120,7 @@ impl RaftNetworkV2 for NetworkConnection { let requests_stream = futures::stream::iter(requests); // Send the streaming snapshot request - let response = - self.client.snapshot(requests_stream).await.map_err(|e| RPCError::Network(NetworkError::new(&e)))?; + let response = client.snapshot(requests_stream).await.map_err(|e| RPCError::Network(NetworkError::new(&e)))?; let message = response.into_inner(); // Deserialize the response @@ -120,6 +133,15 @@ impl RaftNetworkV2 for NetworkConnection { req: VoteRequest, _option: RPCOption, ) -> Result, RPCError> { + let server_addr = self.target_node.rpc_addr.clone(); + let channel = match Channel::builder(format!("http://{}", server_addr).parse().unwrap()).connect().await { + Ok(channel) => channel, + Err(e) => { + return Err(openraft::error::RPCError::Unreachable(Unreachable::new(&e))); + } + }; + let mut client = InternalServiceClient::new(channel); + // Convert the openraft VoteRequest to protobuf VoteRequest let proto_vote_req: PbVoteRequest = req.into(); @@ -127,7 +149,7 @@ impl RaftNetworkV2 for NetworkConnection { let request = tonic::Request::new(proto_vote_req); // Send the vote request - let response = self.client.vote(request).await.map_err(|e| RPCError::Network(NetworkError::new(&e)))?; + let response = client.vote(request).await.map_err(|e| RPCError::Network(NetworkError::new(&e)))?; // Convert the response back to openraft VoteResponse let proto_vote_resp: PbVoteResponse = response.into_inner();