Skip to content

Commit

Permalink
Use single NATS connection
Browse files (browse the repository at this point in the history)
  • Loading branch information
nazar-pc committed Jun 6, 2024
1 parent 0f6f7fd commit 6d13125
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 64 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/subspace-farmer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ include = [
[dependencies]
anyhow = "1.0.82"
async-lock = "3.3.0"
async-nats = "0.35.0"
async-nats = "0.35.1"
async-trait = "0.1.80"
backoff = { version = "0.4.0", features = ["futures", "tokio"] }
base58 = "0.2.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use prometheus_client::registry::Registry;
use std::env::current_exe;
use std::mem;
use std::net::SocketAddr;
use std::num::NonZeroUsize;
use subspace_farmer::cluster::nats_client::NatsClient;
use subspace_farmer::utils::AsyncJoinOnDrop;
use subspace_metrics::{start_prometheus_metrics_server, RegistryAdapter};
Expand Down Expand Up @@ -53,12 +52,6 @@ struct SharedArgs {
/// which can be done by starting NATS server with config file containing `max_payload = 2MB`.
#[arg(long, alias = "nats-server", required = true)]
nats_servers: Vec<ServerAddr>,
/// Size of connection pool of NATS clients.
///
/// Pool size can be increased in case of large number of farms or high plotting capacity of
/// this instance.
#[arg(long, default_value = "8")]
nats_pool_size: NonZeroUsize,
/// Defines endpoints for the prometheus metrics server. It doesn't start without at least
/// one specified endpoint. Format: 127.0.0.1:8080
#[arg(long, aliases = ["metrics-endpoint", "metrics-endpoints"])]
Expand Down Expand Up @@ -101,7 +94,6 @@ where
} = cluster_args;
let SharedArgs {
nats_servers,
nats_pool_size,
prometheus_listen_on,
} = shared_args;
let ClusterSubcommands { mut subcommand } = subcommands;
Expand All @@ -112,7 +104,6 @@ where
max_elapsed_time: None,
..ExponentialBackoff::default()
},
nats_pool_size,
)
.await
.map_err(|error| anyhow!("Failed to connect to NATS server: {error}"))?;
Expand Down
81 changes: 29 additions & 52 deletions crates/subspace-farmer/src/cluster/nats_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,13 @@ use backoff::backoff::Backoff;
use backoff::ExponentialBackoff;
use derive_more::{Deref, DerefMut};
use futures::channel::mpsc;
use futures::stream::FuturesUnordered;
use futures::{FutureExt, Stream, StreamExt};
use parity_scale_codec::{Decode, Encode};
use std::any::type_name;
use std::collections::VecDeque;
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
Expand Down Expand Up @@ -273,11 +270,8 @@ impl<Response> StreamResponseSubscriber<Response> {

let background_task = AsyncJoinOnDrop::new(
tokio::spawn(async move {
// Make sure to use the same exact NATS connection for all acknowledgements in order to
// ensure consistent ordering
let client = &*nats_client;
while let Some((subject, index)) = acknowledgement_receiver.next().await {
if let Err(error) = client
if let Err(error) = nats_client
.publish(subject.clone(), index.to_le_bytes().to_vec().into())
.await
{
Expand Down Expand Up @@ -363,8 +357,7 @@ where

#[derive(Debug)]
struct Inner {
clients: Vec<Client>,
next_client: AtomicUsize,
client: Client,
request_retry_backoff_policy: ExponentialBackoff,
approximate_max_message_size: usize,
}
Expand All @@ -380,7 +373,7 @@ impl Deref for NatsClient {

#[inline]
fn deref(&self) -> &Self::Target {
self.client()
&self.inner.client
}
}

Expand All @@ -389,37 +382,24 @@ impl NatsClient {
pub async fn new<A: ToServerAddrs>(
addrs: A,
request_retry_backoff_policy: ExponentialBackoff,
nats_pool_size: NonZeroUsize,
) -> Result<Self, async_nats::Error> {
let servers = addrs.to_server_addrs()?.collect::<Vec<_>>();
Self::from_clients(
(0..nats_pool_size.get())
.map(|_| async {
async_nats::connect_with_options(
&servers,
ConnectOptions::default().request_timeout(Some(REQUEST_TIMEOUT)),
)
.await
})
.collect::<FuturesUnordered<_>>()
.collect::<Vec<_>>()
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()?,
Self::from_client(
async_nats::connect_with_options(
&servers,
ConnectOptions::default().request_timeout(Some(REQUEST_TIMEOUT)),
)
.await?,
request_retry_backoff_policy,
)
}

/// Create new client from existing NATS instance
pub fn from_clients(
clients: Vec<Client>,
pub fn from_client(
client: Client,
request_retry_backoff_policy: ExponentialBackoff,
) -> Result<Self, async_nats::Error> {
let max_payload = clients
.first()
.ok_or("Empty list of NATS clients is not supported; qed")?
.server_info()
.max_payload;
let max_payload = client.server_info().max_payload;
if max_payload < EXPECTED_MESSAGE_SIZE {
return Err(format!(
"Max payload {max_payload} is smaller than expected {EXPECTED_MESSAGE_SIZE}, \
Expand All @@ -429,8 +409,7 @@ impl NatsClient {
}

let inner = Inner {
clients,
next_client: AtomicUsize::default(),
client,
request_retry_backoff_policy,
// Allow up to 90%, the rest will be wrapper data structures, etc.
approximate_max_message_size: max_payload * 9 / 10,
Expand Down Expand Up @@ -460,7 +439,8 @@ impl NatsClient {
let mut maybe_retry_backoff = None;
let message = loop {
match self
.client()
.inner
.client
.request(subject.clone(), request.encode().into())
.await
{
Expand Down Expand Up @@ -529,12 +509,14 @@ impl NatsClient {
let stream_request = StreamRequest::new(request);

let subscriber = self
.client()
.inner
.client
.subscribe(stream_request.response_subject.clone())
.await?;
debug!(request_type = %type_name::<Request>(), ?subscriber, "Stream request subscription");

self.client()
self.inner
.client
.publish(
subject_with_instance(Request::SUBJECT, instance),
stream_request.encode().into(),
Expand All @@ -554,15 +536,12 @@ impl NatsClient {
GenericStreamResponses<<Request as GenericStreamRequest>::Response>;

let mut response_stream = response_stream.fuse();
// Make sure to use the same exact NATS connection for all acknowledgements in order to
// ensure consistent ordering
let client = &**self;

// Pull the first element to measure response size
let first_element = match response_stream.next().await {
Some(first_element) => first_element,
None => {
if let Err(error) = client
if let Err(error) = self
.publish(
response_subject.clone(),
Response::<Request>::Last {
Expand Down Expand Up @@ -648,7 +627,7 @@ impl NatsClient {
}
};

if let Err(error) = client
if let Err(error) = self
.publish(response_subject.clone(), response.encode().into())
.await
{
Expand Down Expand Up @@ -746,7 +725,8 @@ impl NatsClient {
where
Notification: GenericNotification,
{
self.client()
self.inner
.client
.publish(
subject_with_instance(Notification::SUBJECT, instance),
notification.encode().into(),
Expand All @@ -763,7 +743,8 @@ impl NatsClient {
where
Broadcast: GenericBroadcast,
{
self.client()
self.inner
.client
.publish_with_headers(
Broadcast::SUBJECT.replace('*', instance),
{
Expand Down Expand Up @@ -820,12 +801,6 @@ impl NatsClient {
.await
}

/// Get NATS client from a pool
fn client(&self) -> &Client {
let client = self.inner.next_client.fetch_add(1, Ordering::Relaxed);
&self.inner.clients[client % self.inner.clients.len()]
}

/// Simple subscription that will produce decoded messages, while skipping messages that fail to
/// decode
async fn simple_subscribe<Message>(
Expand All @@ -838,11 +813,13 @@ impl NatsClient {
Message: Decode,
{
let subscriber = if let Some(queue_group) = queue_group {
self.client()
self.inner
.client
.queue_subscribe(subject_with_instance(subject, instance), queue_group)
.await?
} else {
self.client()
self.inner
.client
.subscribe(subject_with_instance(subject, instance))
.await?
};
Expand Down

0 comments on commit 6d13125

Please sign in to comment.