diff --git a/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs
index 955d24d9e9..f52595e8b3 100644
--- a/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs
+++ b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs
@@ -5,11 +5,17 @@ use crate::cluster_topology::TopologyHash;
 use dashmap::DashMap;
 use futures::FutureExt;
 use rand::seq::IteratorRandom;
+use std::collections::{HashMap, HashSet};
 use std::net::IpAddr;
 use std::sync::atomic::Ordering;
 use std::sync::Arc;
 use telemetrylib::Telemetry;
+use tracing::debug;
+
+use tokio::sync::Notify;
+use tokio::task::JoinHandle;
+
 /// Count the number of connections in a connections_map object
 macro_rules! count_connections {
     ($conn_map:expr) => {{
@@ -134,11 +140,213 @@ impl std::fmt::Display for ConnectionsMap {
     }
 }
 
+#[derive(Clone, Debug)]
+pub(crate) struct RefreshTaskNotifier {
+    notify: Arc<Notify>,
+}
+
+impl RefreshTaskNotifier {
+    fn new() -> Self {
+        RefreshTaskNotifier {
+            notify: Arc::new(Notify::new()),
+        }
+    }
+
+    pub fn get_notifier(&self) -> Arc<Notify> {
+        self.notify.clone()
+    }
+
+    pub fn notify(&self) {
+        self.notify.notify_waiters();
+    }
+}
+
+// Enum representing the task status during a connection refresh.
+//
+// - **Reconnecting**:
+//   Indicates that a refresh task is in progress. This status includes a dedicated
+//   notifier (`RefreshTaskNotifier`) so that other tasks can wait for the connection
+//   to be refreshed before proceeding.
+//
+// - **ReconnectingTooLong**:
+//   Represents a situation where a refresh task has taken too long to complete.
+//   The status transitions from `Reconnecting` to `ReconnectingTooLong` under specific
+//   conditions (e.g., after one reconnect attempt inside the task, or after a timeout).
+//
+// The transition from `Reconnecting` to `ReconnectingTooLong` is managed exclusively
+// within the `update_refreshed_connection` function in `poll_flush`. This ensures that
+// all requests maintain a consistent view of the connections.
+//
+// When transitioning from `Reconnecting` to `ReconnectingTooLong`, the associated
+// notifier is triggered to unblock all awaiting tasks.
+#[derive(Clone, Debug)]
+pub(crate) enum RefreshTaskStatus {
+    // The task is actively reconnecting. Includes a notifier for tasks to wait on.
+    Reconnecting(RefreshTaskNotifier),
+    // The task has exceeded the allowed reconnection time.
+    #[allow(dead_code)]
+    ReconnectingTooLong,
+}
+
+impl Drop for RefreshTaskStatus {
+    fn drop(&mut self) {
+        if let RefreshTaskStatus::Reconnecting(notifier) = self {
+            debug!("RefreshTaskStatus: Dropped while in Reconnecting status. Notifying tasks.");
+            notifier.notify();
+        }
+    }
+}
+
+impl RefreshTaskStatus {
+    /// Creates a new `RefreshTaskStatus` in the `Reconnecting` status with a fresh `RefreshTaskNotifier`.
+    pub fn new() -> Self {
+        debug!("RefreshTaskStatus: Initialized in Reconnecting status with a new notifier.");
+        RefreshTaskStatus::Reconnecting(RefreshTaskNotifier::new())
+    }
+
+    // Transitions the current status from `Reconnecting` to `ReconnectingTooLong` in place.
+    //
+    // If the current status is `Reconnecting`, this method notifies all waiting tasks
+    // using the embedded `RefreshTaskNotifier` and updates the status to `ReconnectingTooLong`.
+    //
+    // If the status is already `ReconnectingTooLong`, this method does nothing.
+    #[allow(dead_code)]
+    pub fn flip_status_to_too_long(&mut self) {
+        if let RefreshTaskStatus::Reconnecting(notifier) = self {
+            debug!(
+                "RefreshTaskStatus: Notifying tasks before transitioning to ReconnectingTooLong."
+            );
+            notifier.notify();
+            *self = RefreshTaskStatus::ReconnectingTooLong;
+        } else {
+            debug!("RefreshTaskStatus: Already in ReconnectingTooLong status.");
+        }
+    }
+
+    pub fn notify_waiting_requests(&mut self) {
+        if let RefreshTaskStatus::Reconnecting(notifier) = self {
+            debug!("RefreshTaskStatus::notify_waiting_requests notify");
+            notifier.notify();
+        } else {
+            debug!("RefreshTaskStatus::notify_waiting_requests - ReconnectingTooLong status.");
+        }
+    }
+}
+
+// Struct combining the task handle and its status
+#[derive(Debug)]
+pub(crate) struct RefreshTaskState {
+    pub handle: JoinHandle<()>,
+    pub status: RefreshTaskStatus,
+}
+
+impl RefreshTaskState {
+    // Creates a new `RefreshTaskState` in the `Reconnecting` state with a new notifier.
+    pub fn new(handle: JoinHandle<()>) -> Self {
+        debug!("RefreshTaskState: Creating a new instance with a Reconnecting state.");
+        RefreshTaskState {
+            handle,
+            status: RefreshTaskStatus::new(),
+        }
+    }
+}
+
+impl Drop for RefreshTaskState {
+    fn drop(&mut self) {
+        if let RefreshTaskStatus::Reconnecting(ref notifier) = self.status {
+            debug!("RefreshTaskState: Dropped while in Reconnecting status. Notifying tasks.");
+            notifier.notify();
+        } else {
+            debug!("RefreshTaskState: Dropped while in ReconnectingTooLong status.");
+        }
+
+        // Abort the task handle if it's not yet finished
+        if !self.handle.is_finished() {
+            debug!("RefreshTaskState: Aborting unfinished task.");
+            self.handle.abort();
+        } else {
+            debug!("RefreshTaskState: Task already finished, no abort necessary.");
+        }
+    }
+}
+
+// This struct is used to track the status of each address refresh
+pub(crate) struct RefreshConnectionStates<Connection> {
+    // Holds all the failed addresses that started a refresh task.
+    pub(crate) refresh_addresses_started: HashSet<String>,
+    // Tracks the ongoing refresh task for each address.
+    pub(crate) refresh_address_in_progress: HashMap<String, RefreshTaskState>,
+    // Holds all the refreshed addresses that are ready to be inserted into the connection_map
+    pub(crate) refresh_addresses_done: HashMap<String, Option<ClusterNode<Connection>>>,
+}
+
+impl<Connection> RefreshConnectionStates<Connection> {
+    // Clears all ongoing refresh connection tasks and resets associated state tracking.
+    //
+    // - This method removes all entries in the `refresh_address_in_progress` map.
+    // - The `Drop` implementation is responsible for notifying the associated notifiers and aborting any unfinished refresh tasks.
+    // - Additionally, this method clears `refresh_addresses_started` and `refresh_addresses_done`
+    //   to ensure no stale data remains in the refresh state tracking.
+    pub(crate) fn clear_refresh_state(&mut self) {
+        debug!(
+            "clear_refresh_state: removing all in-progress refresh connection tasks for addresses: {:?}",
+            self.refresh_address_in_progress.keys().collect::<Vec<_>>()
+        );
+
+        // Clear the entire map; Drop handles the cleanup
+        self.refresh_address_in_progress.clear();
+
+        // Clear other state tracking
+        self.refresh_addresses_started.clear();
+        self.refresh_addresses_done.clear();
+    }
+
+    // Collects the notifiers for the given addresses and returns them as a vector.
+    //
+    // This function retrieves the notifiers for the provided addresses from the `refresh_address_in_progress`
+    // map and returns them, so they can be awaited outside of the lock.
+    //
+    // # Arguments
+    // * `addresses` - A list of addresses for which notifiers are required.
+    //
+    // # Returns
+    // A vector of `Arc<Notify>` handles whose `notified()` futures can be awaited.
+    pub(crate) fn collect_refresh_notifiers(
+        &self,
+        addresses: &HashSet<String>,
+    ) -> Vec<Arc<Notify>> {
+        addresses
+            .iter()
+            .filter_map(|address| {
+                self.refresh_address_in_progress
+                    .get(address)
+                    .and_then(|refresh_state| match &refresh_state.status {
+                        RefreshTaskStatus::Reconnecting(notifier) => {
+                            Some(notifier.get_notifier())
+                        }
+                        _ => None,
+                    })
+            })
+            .collect()
+    }
+}
+
+impl<Connection> Default for RefreshConnectionStates<Connection> {
+    fn default() -> Self {
+        Self {
+            refresh_addresses_started: HashSet::new(),
+            refresh_address_in_progress: HashMap::new(),
+            refresh_addresses_done: HashMap::new(),
+        }
+    }
+}
+
 pub(crate) struct ConnectionsContainer<Connection> {
     connection_map: DashMap<String, ClusterNode<Connection>>,
     pub(crate) slot_map: SlotMap,
     read_from_replica_strategy: ReadFromReplicaStrategy,
     topology_hash: TopologyHash,
+    pub(crate) refresh_conn_state: RefreshConnectionStates<Connection>,
 }
 
 impl<Connection> Drop for ConnectionsContainer<Connection> {
@@ -155,6 +363,7 @@ impl Default for ConnectionsContainer {
             slot_map: Default::default(),
             read_from_replica_strategy: ReadFromReplicaStrategy::AlwaysFromPrimary,
             topology_hash: 0,
+            refresh_conn_state: Default::default(),
         }
     }
 }
@@ -182,6 +391,7 @@ where
             slot_map,
             read_from_replica_strategy,
             topology_hash,
+            refresh_conn_state: Default::default(),
         }
     }
@@ -337,6 +547,51 @@ where
         })
     }
 
+    // Fetches the master address for a given route.
+    // Returns `None` if no master address can be resolved.
+    pub(crate) fn address_for_route(&self, route: &Route) -> Option<String> {
+        let slot_map_value = self.slot_map.slot_value_for_route(route)?;
+        Some(slot_map_value.addrs.primary().clone().to_string())
+    }
+
+    // Retrieves the notifier for a reconnect task associated with a given route.
+    // Returns `Some(Arc<Notify>)` if a reconnect task is in the `Reconnecting` state.
+    // Returns `None` if:
+    // - There is no refresh task for the route's address.
+    // - The reconnect task is in the `ReconnectingTooLong` state, with a debug log for clarity.
+    pub(crate) fn notifier_for_route(&self, route: &Route) -> Option<Arc<Notify>> {
+        let address = self.address_for_route(route)?;
+
+        if let Some(task_state) = self
+            .refresh_conn_state
+            .refresh_address_in_progress
+            .get(&address)
+        {
+            match &task_state.status {
+                RefreshTaskStatus::Reconnecting(notifier) => {
+                    debug!(
+                        "notifier_for_route: Found reconnect notifier for address: {}",
+                        address
+                    );
+                    Some(notifier.get_notifier())
+                }
+                RefreshTaskStatus::ReconnectingTooLong => {
+                    debug!(
+                        "notifier_for_route: Address {} is in ReconnectingTooLong state. No notifier will be returned.",
+                        address
+                    );
+                    None
+                }
+            }
+        } else {
+            debug!(
+                "notifier_for_route: No refresh task exists for address: {}.
No notifier will be returned.", + address + ); + None + } + } + pub(crate) fn all_node_connections( &self, ) -> impl Iterator> + '_ { @@ -572,6 +827,7 @@ mod tests { connection_map, read_from_replica_strategy: ReadFromReplicaStrategy::AZAffinity("use-1a".to_string()), topology_hash: 0, + refresh_conn_state: Default::default(), } } @@ -628,6 +884,7 @@ mod tests { connection_map, read_from_replica_strategy: strategy, topology_hash: 0, + refresh_conn_state: Default::default(), } } diff --git a/glide-core/redis-rs/redis/src/cluster_async/mod.rs b/glide-core/redis-rs/redis/src/cluster_async/mod.rs index 534fdd429e..3334c620fe 100644 --- a/glide-core/redis-rs/redis/src/cluster_async/mod.rs +++ b/glide-core/redis-rs/redis/src/cluster_async/mod.rs @@ -40,10 +40,13 @@ use crate::{ commands::cluster_scan::{cluster_scan, ClusterScanArgs, ScanStateRC}, FromRedisValue, InfoDict, }; +use connections_container::{RefreshTaskState, RefreshTaskStatus}; use dashmap::DashMap; use std::{ collections::{HashMap, HashSet}, - fmt, io, mem, + fmt, io, + iter::once, + mem, net::{IpAddr, SocketAddr}, pin::Pin, sync::{ @@ -1301,7 +1304,7 @@ where .extend_connection_map(connection_map); if let Err(err) = Self::refresh_slots_and_subscriptions_with_retries( inner.clone(), - &RefreshPolicy::Throttable, + &RefreshPolicy::NotThrottable, ) .await { @@ -1341,13 +1344,13 @@ where } // identify nodes with closed connection - let mut addrs_to_refresh = Vec::new(); + let mut addrs_to_refresh = HashSet::new(); for (addr, con_fut) in &all_valid_conns { let con = con_fut.clone().await; // connection object might be present despite the transport being closed if con.is_closed() { // transport is closed, need to refresh - addrs_to_refresh.push(addr.clone()); + addrs_to_refresh.insert(addr.clone()); } } @@ -1361,7 +1364,7 @@ where if !addrs_to_refresh.is_empty() { // don't try existing nodes since we know a. it does not exist. b. exist but its connection is closed - Self::refresh_connections( + Self::refresh_and_update_connections( inner.clone(), addrs_to_refresh, RefreshConnectionType::AllConnections, @@ -1371,62 +1374,189 @@ where } } - async fn refresh_connections( + // Creates refresh tasks, await on the tasks' notifier and the update the connection_container. + // Awaiting on the notifier guaranties at least one reconnect attempt on each address. 
+ async fn refresh_and_update_connections( inner: Arc>, - addresses: Vec, + addresses: HashSet, conn_type: RefreshConnectionType, check_existing_conn: bool, ) { - info!("Started refreshing connections to {:?}", addresses); - let mut tasks = FuturesUnordered::new(); - let inner = inner.clone(); + trace!("refresh_and_update_connections: calling trigger_refresh_connection_tasks"); + Self::trigger_refresh_connection_tasks( + inner.clone(), + addresses.clone(), + conn_type, + check_existing_conn, + ) + .await; - for address in addresses.into_iter() { - let inner = inner.clone(); + trace!("refresh_and_update_connections: Await on all tasks' refresh notifier"); + // Await on all tasks' refresh notifier if exists + let refresh_task_notifiers = inner + .clone() + .conn_lock + .read() + .expect(MUTEX_READ_ERR) + .refresh_conn_state + .collect_refresh_notifiers(&addresses); + let futures: Vec<_> = refresh_task_notifiers + .iter() + .map(|notify| notify.notified()) + .collect(); + futures::future::join_all(futures).await; - tasks.push(async move { - let node_option = if check_existing_conn { - let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR); - connections_container.remove_node(&address) - } else { - None - }; + // Update the connections in the connection_container + Self::update_refreshed_connection(inner); + } - // Override subscriptions for this connection - let mut cluster_params = inner.cluster_params.read().expect(MUTEX_READ_ERR).clone(); - let subs_guard = inner.subscriptions_by_address.read().await; - cluster_params.pubsub_subscriptions = subs_guard.get(&address).cloned(); + async fn trigger_refresh_connection_tasks( + inner: Arc>, + addresses: HashSet, + conn_type: RefreshConnectionType, + check_existing_conn: bool, + ) { + debug!("Triggering refresh connections tasks to {:?} ", addresses); + + for address in addresses { + if inner + .conn_lock + .read() + .expect(MUTEX_READ_ERR) + .refresh_conn_state + .refresh_address_in_progress + .contains_key(&address) + { + info!("Skipping refresh for {}: already in progress", address); + continue; + } + + let inner_clone = inner.clone(); + let address_clone = address.clone(); + let address_clone_for_task = address.clone(); + + // Add this address to be removed in poll_flush so all requests see a consistent connection map. + // See next comment for elaborated explanation. 
+ inner_clone + .conn_lock + .write() + .expect(MUTEX_READ_ERR) + .refresh_conn_state + .refresh_addresses_started + .insert(address_clone_for_task.clone()); + + let node_option = if check_existing_conn { + let connections_container = inner_clone.conn_lock.read().expect(MUTEX_READ_ERR); + connections_container + .connection_map() + .get(&address_clone_for_task) + .map(|node| node.value().clone()) + } else { + None + }; + + let handle = tokio::spawn(async move { + info!( + "refreshing connection task to {:?} started", + address_clone_for_task + ); + + let mut cluster_params = inner_clone + .cluster_params + .read() + .expect(MUTEX_READ_ERR) + .clone(); + let subs_guard = inner_clone.subscriptions_by_address.read().await; + cluster_params.pubsub_subscriptions = + subs_guard.get(&address_clone_for_task).cloned(); drop(subs_guard); - let node = get_or_create_conn( - &address, + let node_result = get_or_create_conn( + &address_clone_for_task, node_option, &cluster_params, conn_type, - inner.glide_connection_options.clone(), + inner_clone.glide_connection_options.clone(), ) .await; - (address, node) - }); - } - - // Poll connection tasks as soon as each one finishes - while let Some(result) = tasks.next().await { - match result { - (address, Ok(node)) => { - let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR); - connections_container.replace_or_add_connection_for_address(address, node); + // Maintain the newly refreshed connection separately from the main connection map. + // This refreshed connection will be incorporated into the main connection map at the start of the poll_flush operation. + // This approach ensures that all requests within the current batch interact with a consistent connection map, + // preventing potential reordering issues. + // + // By delaying the integration of the refreshed connection: + // + // 1. We maintain consistency throughout the processing of a batch of requests. + // 2. We avoid mid-batch changes to the connection map that could lead to inconsistent routing or ordering of operations. + // 3. We ensure that all requests in a batch see the same cluster topology, reducing the risk of race conditions or unexpected behavior. + // + // This strategy effectively creates a synchronization point at the beginning of poll_flush, where the connection map is + // updated atomically for the next batch of operations. This approach balances the need for up-to-date connection information + // with the requirement for consistent request handling within each processing cycle. + match node_result { + Ok(node) => { + debug!( + "Succeeded to refresh connection for node {}.", + address_clone_for_task + ); + inner_clone + .conn_lock + .write() + .expect(MUTEX_READ_ERR) + .refresh_conn_state + .refresh_addresses_done + .insert(address_clone_for_task.clone(), Some(node)); + } + Err(err) => { + warn!( + "Failed to refresh connection for node {}. Error: `{:?}`", + address_clone_for_task, err + ); + // TODO - When we move to retry more than once, we add this address to a new set of running to long, and then only move + // the RefreshTaskState.status to RunningTooLong in the poll_flush context inside update_refreshed_connection. 
+ inner_clone + .conn_lock + .write() + .expect(MUTEX_READ_ERR) + .refresh_conn_state + .refresh_addresses_done + .insert(address_clone_for_task.clone(), None); + } } - (address, Err(err)) => { + + // Need to notify here the awaitng requests inorder to awaket the context of the poll_flush as + // it awaits on this notifier inside the get_connection in the poll_next inside poll_complete. + // Otherwise poll_flush won't be polled until the next start_send or other requests I/O. + if let Some(task_state) = inner_clone + .conn_lock + .write() + .expect(MUTEX_READ_ERR) + .refresh_conn_state + .refresh_address_in_progress + .get_mut(&address_clone_for_task) + { + task_state.status.notify_waiting_requests(); + } else { warn!( - "Failed to refresh connection for node {}. Error: `{:?}`", - address, err + "No refresh task state found for address: {}", + address_clone_for_task ); } - } + + info!("Refreshing connection task to {:?} is done", address_clone); + }); + + // Keep the task handle into the RefreshState of this address + inner + .conn_lock + .write() + .expect(MUTEX_READ_ERR) + .refresh_conn_state + .refresh_address_in_progress + .insert(address.clone(), RefreshTaskState::new(handle)); } - debug!("refresh connections completed"); + debug!("trigger_refresh_connection_tasks: Done"); } async fn aggregate_results( @@ -1760,9 +1890,9 @@ where if !addrs_to_refresh.is_empty() { // immediately trigger connection reestablishment - Self::refresh_connections( + Self::refresh_and_update_connections( inner.clone(), - addrs_to_refresh.into_iter().collect(), + addrs_to_refresh, RefreshConnectionType::AllConnections, false, ) @@ -1796,9 +1926,10 @@ where } if !failed_connections.is_empty() { - Self::refresh_connections( + trace!("check_for_topology_diff: calling trigger_refresh_connection_tasks"); + Self::trigger_refresh_connection_tasks( inner, - failed_connections, + failed_connections.into_iter().collect::>(), RefreshConnectionType::OnlyManagementConnection, true, ) @@ -1899,6 +2030,9 @@ where info!("refresh_slots found nodes:\n{new_connections}"); // Reset the current slot map and connection vector with the new ones let mut write_guard = inner.conn_lock.write().expect(MUTEX_WRITE_ERR); + // Clear the refresh tasks of the prev instance + // TODO - Maybe we can take the running refresh tasks and use them instead of running new connection creation + write_guard.refresh_conn_state.clear_refresh_state(); let read_from_replicas = inner .get_cluster_param(|params| params.read_from_replicas.clone()) .expect(MUTEX_READ_ERR); @@ -2215,13 +2349,16 @@ where ) } InternalSingleNodeRouting::SpecificNode(route) => { - if let Some((conn, address)) = core - .conn_lock - .read() - .expect(MUTEX_READ_ERR) - .connection_for_route(&route) - { - ConnectionCheck::Found((conn, address)) + // Step 1: Attempt to get the connection directly using the route. + let conn_check = { + let conn_lock = core.conn_lock.read().expect(MUTEX_READ_ERR); + conn_lock + .connection_for_route(&route) + .map(ConnectionCheck::Found) + }; + + if let Some(conn_check) = conn_check { + conn_check } else { // No connection is found for the given route: // - For key-based commands, attempt redirection to a random node, @@ -2229,6 +2366,9 @@ where // - For non-key-based commands, avoid attempting redirection to a random node // as it wouldn't result in MOVED hints and can lead to unwanted results // (e.g., sending management command to a different node than the user asked for); instead, raise the error. 
+ let mut conn_check = ConnectionCheck::RandomConnection; + + // Step 2: Handle cases where no connection is found for the route. let routable_cmd = cmd.and_then(|cmd| Routable::command(&*cmd)); if routable_cmd.is_some() && !RoutingInfo::is_key_routing_command(&routable_cmd.unwrap()) @@ -2239,10 +2379,51 @@ where format!("{route:?}"), ) .into()); + } + + debug!( + "SpecificNode: No connection found for route `{route:?}`. Checking for reconnect tasks before redirecting to a random node." + ); + + // Step 3: Obtain the reconnect notifier, ensuring the lock is released immediately after. + let reconnect_notifier = { + let conn_lock = core.conn_lock.write().expect(MUTEX_READ_ERR); + conn_lock.notifier_for_route(&route).clone() + }; + + // Step 4: If a notifier exists, wait for it to signal completion. + if let Some(notifier) = reconnect_notifier { + debug!( + "SpecificNode: Waiting on reconnect notifier for route `{route:?}`." + ); + + // Drop the lock before awaiting + notifier.notified().await; + + debug!( + "SpecificNode: Finished waiting on notifier for route `{route:?}`. Retrying connection lookup." + ); + + // Step 5: Retry the connection lookup after waiting for the reconnect task. + if let Some((conn, address)) = core + .conn_lock + .read() + .expect(MUTEX_READ_ERR) + .connection_for_route(&route) + { + conn_check = ConnectionCheck::Found((conn, address)); + } else { + debug!( + "SpecificNode: No connection found for route `{route:?}` after waiting on reconnect notifier. Proceeding to random node." + ); + } } else { - warn!("No connection found for route `{route:?}`. Attempting redirection to a random node."); - ConnectionCheck::RandomConnection + debug!( + "SpecificNode: No active reconnect task for route `{route:?}`. Proceeding to random node." + ); } + + conn_check } } InternalSingleNodeRouting::Random => ConnectionCheck::RandomConnection, @@ -2270,32 +2451,80 @@ where let (address, mut conn) = match conn_check { ConnectionCheck::Found((address, connection)) => (address, connection.await), - ConnectionCheck::OnlyAddress(addr) => { - let mut this_conn_params = core.get_cluster_param(|params| params.clone())?; - let subs_guard = core.subscriptions_by_address.read().await; - this_conn_params.pubsub_subscriptions = subs_guard.get(addr.as_str()).cloned(); - drop(subs_guard); - match connect_and_check::( - &addr, - this_conn_params, - None, + ConnectionCheck::OnlyAddress(address) => { + // No connection for this address in the conn_map + Self::trigger_refresh_connection_tasks( + core.clone(), + HashSet::from_iter(once(address.clone())), RefreshConnectionType::AllConnections, - None, - core.glide_connection_options.clone(), + false, ) - .await - .get_node() + .await; + + let reconnect_notifier: Option> = match core + .conn_lock + .read() + .expect(MUTEX_READ_ERR) + .refresh_conn_state + .refresh_address_in_progress + .get(&address) { - Ok(node) => { - let connection_clone = node.user_connection.conn.clone().await; - let connections = core.conn_lock.read().expect(MUTEX_READ_ERR); - let address = connections.replace_or_add_connection_for_address(addr, node); - drop(connections); - (address, connection_clone) + Some(refresh_task_state) => { + match &refresh_task_state.status { + // If the task status is `Reconnecting`, grab the notifier. 
+ RefreshTaskStatus::Reconnecting(refresh_notifier) => { + Some(refresh_notifier.get_notifier()) + } + RefreshTaskStatus::ReconnectingTooLong => { + debug!( + "get_connection: Address {} is in ReconnectingTooLong state, skipping notifier wait.", + address + ); + None + } + } } - Err(err) => { - return Err(err); + None => { + debug!( + "get_connection: No refresh task found in progress for address: {}", + address + ); + None } + }; + + let mut conn_option = None; + if let Some(refresh_notifier) = reconnect_notifier { + debug!( + "get_connection: Waiting on the refresh notifier for address: {}", + address + ); + // Wait for the refresh task to notify that it's done reconnecting (or transitioning). + refresh_notifier.notified().await; + debug!( + "get_connection: After waiting on the refresh notifier for address: {}", + address + ); + + conn_option = core + .conn_lock + .read() + .expect(MUTEX_READ_ERR) + .connection_for_address(&address); + } + + if let Some((address, conn)) = conn_option { + debug!("get_connection: Connection found for address: {}", address); + // If found, return the connection + (address, conn.await) + } else { + // Otherwise, return an error indicating the connection wasn't found + return Err(( + ErrorKind::ConnectionNotFoundForRoute, + "Requested connection not found", + address, + ) + .into()); } } ConnectionCheck::RandomConnection => { @@ -2326,6 +2555,7 @@ where } fn poll_recover(&mut self, cx: &mut task::Context<'_>) -> Poll> { + trace!("entered poll_recovere"); let recover_future = match &mut self.state { ConnectionState::PollComplete => return Poll::Ready(Ok(())), ConnectionState::Recover(future) => future, @@ -2393,6 +2623,112 @@ where Self::try_request(info, core).await } + fn update_refreshed_connection(inner: Arc>) { + trace!("update_refreshed_connection started"); + loop { + let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR); + + // Check if both sets are empty + if connections_container + .refresh_conn_state + .refresh_addresses_started + .is_empty() + && connections_container + .refresh_conn_state + .refresh_addresses_done + .is_empty() + { + break; + } + + let addresses_to_remove: Vec = connections_container + .refresh_conn_state + .refresh_addresses_started + .iter() + .cloned() + .collect(); + + let addresses_done: Vec = connections_container + .refresh_conn_state + .refresh_addresses_done + .keys() + .cloned() + .collect(); + + drop(connections_container); + + // Process refresh_addresses_started + for address in addresses_to_remove { + inner + .conn_lock + .write() + .expect(MUTEX_READ_ERR) + .refresh_conn_state + .refresh_addresses_started + .remove(&address); + inner + .conn_lock + .write() + .expect(MUTEX_READ_ERR) + .remove_node(&address); + } + + // Process refresh_addresses_done + for address in addresses_done { + // Check if the address exists in refresh_addresses_done + let mut conn_lock_write = inner.conn_lock.write().expect(MUTEX_READ_ERR); + if let Some(conn_option) = conn_lock_write + .refresh_conn_state + .refresh_addresses_done + .get_mut(&address) + { + // Match the content of the Option + match conn_option.take() { + Some(conn) => { + debug!( + "update_refreshed_connection: found refreshed connection for address {}", + address + ); + // Move the node_conn to the function + conn_lock_write + .replace_or_add_connection_for_address(address.clone(), conn); + } + None => { + debug!( + "update_refreshed_connection: task completed, but no connection for address {}", + address + ); + } + } + } + + // Remove this address 
from refresh_addresses_done + conn_lock_write + .refresh_conn_state + .refresh_addresses_done + .remove(&address); + + // Remove this entry from refresh_address_in_progress + if conn_lock_write + .refresh_conn_state + .refresh_address_in_progress + .remove(&address) + .is_some() + { + debug!( + "update_refreshed_connection: Successfully removed refresh state for address: {}", + address + ); + } else { + warn!( + "update_refreshed_connection: No refresh state found to remove for address: {:?}", + address + ); + } + } + } + } + fn poll_complete(&mut self, cx: &mut task::Context<'_>) -> Poll { let retry_params = self .inner @@ -2497,8 +2833,8 @@ where } } Next::Reconnect { request, target } => { - poll_flush_action = - poll_flush_action.change_state(PollFlushAction::Reconnect(vec![target])); + poll_flush_action = poll_flush_action + .change_state(PollFlushAction::Reconnect(HashSet::from_iter([target]))); if let Some(request) = request { self.inner.pending_requests.lock().unwrap().push(request); } @@ -2543,7 +2879,7 @@ where enum PollFlushAction { None, RebuildSlots, - Reconnect(Vec), + Reconnect(HashSet), ReconnectFromInitialConnections, } @@ -2617,6 +2953,12 @@ where return Poll::Pending; } + // Updating the connection_map with all the refreshed_connections + // In case of active poll_recovery, the work should + // take care of the refreshed_connection, add them if still relevant, and kill the refresh_tasks of + // non-relevant addresses. + ClusterConnInner::update_refreshed_connection(self.inner.clone()); + match ready!(self.poll_complete(cx)) { PollFlushAction::None => return Poll::Ready(Ok(())), PollFlushAction::RebuildSlots => { @@ -2629,7 +2971,7 @@ where } PollFlushAction::Reconnect(addresses) => { self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin( - ClusterConnInner::refresh_connections( + ClusterConnInner::trigger_refresh_connection_tasks( self.inner.clone(), addresses, RefreshConnectionType::OnlyUserConnection,
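// Illustrative sketch, not part of the diff above: the request-side wait pattern that the
// refresh tasks rely on. A refresh task owns an `Arc<Notify>`; a request that finds no usable
// connection grabs the notifier (via `notifier_for_route` / `collect_refresh_notifiers`),
// awaits `notified()`, and then retries the lookup. The names `fake_refresh` and
// `REFRESH_WAIT` below are hypothetical; only `tokio::sync::Notify` semantics are assumed.
// Note that `notify_waiters()` only wakes tasks already awaiting `notified()`, so a bounded
// wait (here a timeout) is the safety net against a missed wake-up.

use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify;

async fn fake_refresh(notify: Arc<Notify>) {
    // Stand-in for the real reconnect work: pretend to rebuild the connection, then wake waiters.
    tokio::time::sleep(Duration::from_millis(50)).await;
    notify.notify_waiters();
}

#[tokio::main]
async fn main() {
    let notify = Arc::new(Notify::new());

    // Spawn the "refresh task" that will signal completion.
    tokio::spawn(fake_refresh(notify.clone()));

    // Request side: wait for the refresh to finish, but never block forever.
    const REFRESH_WAIT: Duration = Duration::from_secs(1);
    match tokio::time::timeout(REFRESH_WAIT, notify.notified()).await {
        Ok(()) => println!("refresh finished, retry the connection lookup"),
        Err(_) => println!("refresh took too long, fall back (e.g. to a random node)"),
    }
}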
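// Illustrative sketch, not part of the diff above: the two-phase "stage then apply" pattern
// that `refresh_addresses_done` and `update_refreshed_connection` follow. Refresh tasks only
// stage their result; one apply step (run at the top of poll_flush in the real code) drains
// the staging map into the live connection map, so a whole batch of requests sees a single,
// consistent map. All names and types here (`FakeConn`, `RefreshStages`, `apply_refreshed`)
// are hypothetical stand-ins for the real connection types.

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

#[derive(Clone, Debug)]
struct FakeConn(String);

#[derive(Default)]
struct RefreshStages {
    // address -> Some(new connection) on success, None on failure
    done: HashMap<String, Option<FakeConn>>,
}

fn apply_refreshed(stages: &Arc<Mutex<RefreshStages>>, live: &mut HashMap<String, FakeConn>) {
    // Drain the staging map in one go; successful refreshes replace the live entry,
    // failed ones simply drop the stale entry.
    let drained: Vec<_> = stages.lock().unwrap().done.drain().collect();
    for (address, maybe_conn) in drained {
        match maybe_conn {
            Some(conn) => {
                live.insert(address, conn);
            }
            None => {
                live.remove(&address);
            }
        }
    }
}

fn main() {
    let stages = Arc::new(Mutex::new(RefreshStages::default()));
    let mut live: HashMap<String, FakeConn> = HashMap::new();

    // A "refresh task" stages its result instead of touching `live` directly.
    stages
        .lock()
        .unwrap()
        .done
        .insert("node1:6379".into(), Some(FakeConn("refreshed".into())));
    stages.lock().unwrap().done.insert("node2:6379".into(), None);

    // The synchronization point: apply everything at once.
    apply_refreshed(&stages, &mut live);
    println!("{live:?}");
}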