diff --git a/Cargo.lock b/Cargo.lock index c487697..b179fc8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -930,6 +930,20 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "der" version = "0.7.9" @@ -5214,6 +5228,7 @@ version = "0.0.0" dependencies = [ "async-trait", "chrono", + "dashmap", "futures", "hex", "http", diff --git a/Cargo.toml b/Cargo.toml index 76119c8..4b0c6b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,3 +74,4 @@ tower = { version = "0.4", features = ["buffer", "util"] } async-trait = "0.1" chrono = "0.4" jsonrpc-core = "18.0.0" +dashmap = "6.1" diff --git a/zaino-fetch/src/chain/mempool.rs b/zaino-fetch/src/chain/mempool.rs index 481399c..f673ecd 100644 --- a/zaino-fetch/src/chain/mempool.rs +++ b/zaino-fetch/src/chain/mempool.rs @@ -130,9 +130,17 @@ impl Mempool { let mut txids_to_exclude: HashSet = HashSet::new(); for exclude_txid in &exclude_txids { + // Convert to big endian (server format). + let server_exclude_txid: String = exclude_txid + .chars() + .collect::>() + .chunks(2) + .rev() + .map(|chunk| chunk.iter().collect::()) + .collect(); let matching_txids: Vec<&String> = mempool_txids .iter() - .filter(|txid| txid.starts_with(exclude_txid)) + .filter(|txid| txid.starts_with(&server_exclude_txid)) .collect(); if matching_txids.len() == 1 { diff --git a/zaino-state/Cargo.toml b/zaino-state/Cargo.toml index 0e94b96..7b2af4a 100644 --- a/zaino-state/Cargo.toml +++ b/zaino-state/Cargo.toml @@ -31,6 +31,7 @@ futures = { workspace = true } tonic = { workspace = true } http = { workspace = true } lazy-regex = { workspace = true } +dashmap = { workspace = true } [dev-dependencies] zaino-testutils = { path = "../zaino-testutils" } diff --git a/zaino-state/src/broadcast.rs b/zaino-state/src/broadcast.rs new file mode 100644 index 0000000..0ae659a --- /dev/null +++ b/zaino-state/src/broadcast.rs @@ -0,0 +1,236 @@ +//! Holds zaino-state::Broadcast, a thread safe broadcaster used by the mempool and non-finalised state. + +use dashmap::DashMap; +use std::{collections::HashSet, hash::Hash, sync::Arc}; +use tokio::sync::watch; + +use crate::status::StatusType; + +/// A generic, thread-safe broadcaster that manages mutable state and notifies clients of updates. +#[derive(Clone)] +pub struct Broadcast { + state: Arc>, + notifier: watch::Sender, +} + +impl Broadcast { + /// Creates a new Broadcast instance, uses default dashmap spec. + pub fn new_default() -> Self { + let (notifier, _) = watch::channel(StatusType::Spawning); + Self { + state: Arc::new(DashMap::new()), + notifier, + } + } + + /// Creates a new Broadcast instance, exposes dashmap spec. + pub fn new_custom(capacity: usize, shard_amount: usize) -> Self { + let (notifier, _) = watch::channel(StatusType::Spawning); + Self { + state: Arc::new(DashMap::with_capacity_and_shard_amount( + capacity, + shard_amount, + )), + notifier, + } + } + + /// Inserts or updates an entry in the state and broadcasts an update. + pub fn insert(&self, key: K, value: V, status: StatusType) { + self.state.insert(key, value); + let _ = self.notifier.send(status); + } + + /// Inserts or updates an entry in the state and broadcasts an update. + pub fn insert_set(&self, set: Vec<(K, V)>, status: StatusType) { + for (key, value) in set { + self.state.insert(key, value); + } + let _ = self.notifier.send(status); + } + + /// Inserts only new entries from the set into the state and broadcasts an update. + pub fn insert_filtered_set(&self, set: Vec<(K, V)>, status: StatusType) { + for (key, value) in set { + // Check if the key is already in the map + if self.state.get(&key).is_none() { + self.state.insert(key, value); + } + } + let _ = self.notifier.send(status); + } + + /// Removes an entry from the state and broadcasts an update. + pub fn remove(&self, key: &K, status: StatusType) { + self.state.remove(key); + let _ = self.notifier.send(status); + } + + /// Retrieves a value from the state by key. + pub fn get(&self, key: &K) -> Option> { + self.state + .get(key) + .map(|entry| Arc::new((*entry.value()).clone())) + } + + /// Retrieves a set of values from the state by a list of keys. + pub fn get_set(&self, keys: &[K]) -> Vec<(K, Arc)> { + keys.iter() + .filter_map(|key| { + self.state + .get(key) + .map(|entry| (key.clone(), Arc::new((*entry.value()).clone()))) + }) + .collect() + } + + /// Checks if a key exists in the state. + pub fn contains_key(&self, key: &K) -> bool { + self.state.contains_key(key) + } + + /// Returns a receiver to listen for state update notifications. + pub fn subscribe(&self) -> watch::Receiver { + self.notifier.subscribe() + } + + /// Returns a [`BroadcastSubscriber`] to the [`Broadcast`]. + pub fn subscriber(&self) -> BroadcastSubscriber { + BroadcastSubscriber { + state: self.get_state(), + notifier: self.subscribe(), + } + } + + /// Provides read access to the internal state. + pub fn get_state(&self) -> Arc> { + Arc::clone(&self.state) + } + + /// Returns the whole state excluding keys in the ignore list. + pub fn get_filtered_state(&self, ignore_list: &HashSet) -> Vec<(K, V)> { + self.state + .iter() + .filter(|entry| !ignore_list.contains(entry.key())) + .map(|entry| (entry.key().clone(), entry.value().clone())) + .collect() + } + + /// Clears all entries from the state. + pub fn clear(&self) { + self.state.clear(); + } + + /// Returns the number of entries in the state. + pub fn len(&self) -> usize { + self.state.len() + } + + /// Returns true if the state is empty. + pub fn is_empty(&self) -> bool { + self.state.is_empty() + } + + /// Broadcasts an update. + pub fn notify(&self, status: StatusType) { + if self.notifier.send(status).is_err() { + eprintln!("No subscribers are currently listening for updates."); + } + } +} + +impl Default for Broadcast { + fn default() -> Self { + Self::new_default() + } +} + +impl std::fmt::Debug + for Broadcast +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let state_contents: Vec<_> = self + .state + .iter() + .map(|entry| (entry.key().clone(), entry.value().clone())) + .collect(); + f.debug_struct("Broadcast") + .field("state", &state_contents) + .field("notifier", &"watch::Sender") + .finish() + } +} + +/// A generic, thread-safe broadcaster that manages mutable state and notifies clients of updates. +#[derive(Clone)] +pub struct BroadcastSubscriber { + state: Arc>, + notifier: watch::Receiver, +} + +impl BroadcastSubscriber { + /// Waits on notifier update and returns StatusType. + pub async fn wait_on_notifier(&mut self) -> Result { + self.notifier.changed().await?; + let status = self.notifier.borrow().clone(); + Ok(status) + } + + /// Retrieves a value from the state by key. + pub fn get(&self, key: &K) -> Option> { + self.state + .get(key) + .map(|entry| Arc::new((*entry.value()).clone())) + } + + /// Retrieves a set of values from the state by a list of keys. + pub fn get_set(&self, keys: &[K]) -> Vec<(K, Arc)> { + keys.iter() + .filter_map(|key| { + self.state + .get(key) + .map(|entry| (key.clone(), Arc::new((*entry.value()).clone()))) + }) + .collect() + } + + /// Checks if a key exists in the state. + pub fn contains_key(&self, key: &K) -> bool { + self.state.contains_key(key) + } + + /// Returns the whole state excluding keys in the ignore list. + pub fn get_filtered_state(&self, ignore_list: &HashSet) -> Vec<(K, V)> { + self.state + .iter() + .filter(|entry| !ignore_list.contains(entry.key())) + .map(|entry| (entry.key().clone(), entry.value().clone())) + .collect() + } + + /// Returns the number of entries in the state. + pub fn len(&self) -> usize { + self.state.len() + } + + /// Returns true if the state is empty. + pub fn is_empty(&self) -> bool { + self.state.is_empty() + } +} + +impl std::fmt::Debug + for BroadcastSubscriber +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let state_contents: Vec<_> = self + .state + .iter() + .map(|entry| (entry.key().clone(), entry.value().clone())) + .collect(); + f.debug_struct("Broadcast") + .field("state", &state_contents) + .field("notifier", &"watch::Sender") + .finish() + } +} diff --git a/zaino-state/src/error.rs b/zaino-state/src/error.rs index d89f7a7..25c40e2 100644 --- a/zaino-state/src/error.rs +++ b/zaino-state/src/error.rs @@ -85,6 +85,10 @@ pub enum FetchServiceError { #[error("JsonRpcConnector error: {0}")] JsonRpcConnectorError(#[from] zaino_fetch::jsonrpc::error::JsonRpcConnectorError), + /// Error from the mempool. + #[error("Mempool error: {0}")] + MempoolError(#[from] MempoolError), + /// RPC error in compatibility with zcashd. #[error("RPC error: {0:?}")] RpcError(#[from] zaino_fetch::jsonrpc::connector::RpcError), @@ -132,6 +136,9 @@ impl From for tonic::Status { FetchServiceError::JsonRpcConnectorError(err) => { tonic::Status::internal(format!("JsonRpcConnector error: {}", err)) } + FetchServiceError::MempoolError(err) => { + tonic::Status::internal(format!("Mempool error: {}", err)) + } FetchServiceError::RpcError(err) => { tonic::Status::internal(format!("RPC error: {:?}", err)) } @@ -160,3 +167,45 @@ impl From for tonic::Status { } } } + +/// Errors related to the `StateService`. +#[derive(Debug, thiserror::Error)] +pub enum MempoolError { + /// Custom Errors. *Remove before production. + #[error("Custom error: {0}")] + Custom(String), + + /// Error from a Tokio JoinHandle. + #[error("Join error: {0}")] + JoinError(#[from] tokio::task::JoinError), + + /// Error from JsonRpcConnector. + #[error("JsonRpcConnector error: {0}")] + JsonRpcConnectorError(#[from] zaino_fetch::jsonrpc::error::JsonRpcConnectorError), + + /// Error from a Tokio Watch Reciever. + #[error("Join error: {0}")] + WatchRecvError(#[from] tokio::sync::watch::error::RecvError), + + /// Unexpected status-related error. + #[error("Status error: {0:?}")] + StatusError(StatusError), + + /// Error from sending to a Tokio MPSC channel. + #[error("Send error: {0}")] + SendError( + #[from] + tokio::sync::mpsc::error::SendError< + Result<(crate::mempool::MempoolKey, crate::mempool::MempoolValue), StatusError>, + >, + ), + + /// A generic boxed error. + #[error("Generic error: {0}")] + Generic(#[from] Box), +} + +/// A general error type to represent error StatusTypes. +#[derive(Debug, Clone, thiserror::Error)] +#[error("Unexpected status error: {0:?}")] +pub struct StatusError(pub crate::status::StatusType); diff --git a/zaino-state/src/fetch.rs b/zaino-state/src/fetch.rs index d733985..05ff439 100644 --- a/zaino-state/src/fetch.rs +++ b/zaino-state/src/fetch.rs @@ -5,6 +5,7 @@ use crate::{ error::FetchServiceError, get_build_info, indexer::{LightWalletIndexer, ZcashIndexer}, + mempool::{Mempool, MempoolSubscriber}, status::{AtomicStatus, StatusType}, stream::{ AddressStream, CompactBlockStream, CompactTransactionStream, RawTransactionStream, @@ -33,14 +34,17 @@ use zebra_rpc::methods::{ GetBlockChainInfo, GetBlockTransaction, GetInfo, GetRawTransaction, SentTransactionHash, }; -/// Chain fetch service backed by Zcashds JsonRPC service. +/// Chain fetch service backed by Zcashd's JsonRPC engine. +/// +/// This service is a central service, [`FetchServiceSubscriber`] should be created to fetch data. +/// This is done to enable large numbers of concurrent subscribers without significant slowdowns. #[derive(Debug, Clone)] pub struct FetchService { /// JsonRPC Client. fetcher: JsonRpcConnector, // TODO: Add Internal Non-Finalised State - /// Sync task handle. - // sync_task_handle: tokio::task::JoinHandle<()>, + /// Internal mempool. + mempool: Mempool, /// Service metadata. data: ServiceMetadata, /// StateService config data. @@ -66,6 +70,8 @@ impl FetchService { ) .await?; + let mempool = Mempool::spawn(&fetcher, None).await?; + let zebra_build_data = fetcher.get_info().await?; let data = ServiceMetadata { @@ -77,6 +83,7 @@ impl FetchService { let state_service = Self { fetcher, + mempool, data, config, status: AtomicStatus::new(StatusType::Spawning.into()), @@ -84,13 +91,24 @@ impl FetchService { state_service.status.store(StatusType::Syncing.into()); - // TODO: Wait for Non-Finalised state to sync or for mempool to come online. + // TODO: Wait for Non-Finalised state to sync. state_service.status.store(StatusType::Ready.into()); Ok(state_service) } + /// Returns a [`FetchServiceSubscriber`]. + pub fn subscriber(&self) -> FetchServiceSubscriber { + FetchServiceSubscriber { + fetcher: self.fetcher.clone(), + mempool: self.mempool.subscriber(), + data: self.data.clone(), + config: self.config.clone(), + status: self.status.clone(), + } + } + /// Fetches the current status pub fn status(&self) -> StatusType { self.status.load().into() @@ -98,7 +116,7 @@ impl FetchService { /// Shuts down the StateService. pub fn close(&mut self) { - // self.sync_task_handle.abort(); + self.mempool.close(); } } @@ -108,8 +126,33 @@ impl Drop for FetchService { } } +/// A fetch service subscriber. +/// +/// Subscribers should be +#[derive(Debug, Clone)] +pub struct FetchServiceSubscriber { + /// JsonRPC Client. + fetcher: JsonRpcConnector, + // TODO: Add Internal Non-Finalised State + /// Internal mempool. + mempool: MempoolSubscriber, + /// Service metadata. + data: ServiceMetadata, + /// StateService config data. + config: FetchServiceConfig, + /// Thread-safe status indicator. + status: AtomicStatus, +} + +impl FetchServiceSubscriber { + /// Fetches the current status + pub fn status(&self) -> StatusType { + self.status.load().into() + } +} + #[async_trait] -impl ZcashIndexer for FetchService { +impl ZcashIndexer for FetchServiceSubscriber { type Error = FetchServiceError; /// Returns software information from the RPC server, as a [`GetInfo`] JSON struct. @@ -254,7 +297,14 @@ impl ZcashIndexer for FetchService { /// method: post /// tags: blockchain async fn get_raw_mempool(&self) -> Result, Self::Error> { - Ok(self.fetcher.get_raw_mempool().await?.transactions) + // Ok(self.fetcher.get_raw_mempool().await?.transactions) + Ok(self + .mempool + .get_mempool() + .await + .into_iter() + .map(|(key, _)| key.0) + .collect()) } /// Returns information about the given block's Sapling & Orchard tree state. @@ -411,7 +461,7 @@ impl ZcashIndexer for FetchService { } #[async_trait] -impl LightWalletIndexer for FetchService { +impl LightWalletIndexer for FetchServiceSubscriber { type Error = FetchServiceError; /// Return the height of the tip of the best chain @@ -556,7 +606,7 @@ impl LightWalletIndexer for FetchService { let (channel_tx, channel_rx) = tokio::sync::mpsc::channel(self.config.service_channel_size as usize); tokio::spawn(async move { - let timeout = timeout(std::time::Duration::from_secs(service_timeout as u64), async { + let timeout = timeout(std::time::Duration::from_secs((service_timeout*4) as u64), async { for height in start..=end { let height = if rev_order { end - (height - start) @@ -670,7 +720,7 @@ impl LightWalletIndexer for FetchService { let (channel_tx, channel_rx) = tokio::sync::mpsc::channel(self.config.service_channel_size as usize); tokio::spawn(async move { - let timeout = timeout(std::time::Duration::from_secs(service_timeout as u64), async { + let timeout = timeout(std::time::Duration::from_secs((service_timeout*4) as u64), async { for height in start..=end { let height = if rev_order { end - (height - start) @@ -833,7 +883,7 @@ impl LightWalletIndexer for FetchService { let (channel_tx, channel_rx) = tokio::sync::mpsc::channel(self.config.service_channel_size as usize); tokio::spawn(async move { - let timeout = timeout(std::time::Duration::from_secs(service_timeout as u64), async { + let timeout = timeout(std::time::Duration::from_secs((service_timeout*4) as u64), async { for txid in txids { let transaction = fetch_service_clone.get_raw_transaction(txid, Some(1)).await; match transaction { @@ -932,7 +982,7 @@ impl LightWalletIndexer for FetchService { tokio::sync::mpsc::channel::(self.config.service_channel_size as usize); let fetcher_task_handle = tokio::spawn(async move { let fetcher_timeout = timeout( - std::time::Duration::from_secs(service_timeout as u64), + std::time::Duration::from_secs((service_timeout*4) as u64), async { let mut total_balance: u64 = 0; loop { @@ -971,7 +1021,7 @@ impl LightWalletIndexer for FetchService { // NOTE: This timeout is so slow due to the blockcache not being implemented. This should be reduced to 30s once functionality is in place. // TODO: Make [rpc_timout] a configurable system variable with [default = 30s] and [mempool_rpc_timout = 4*rpc_timeout] let addr_recv_timeout = timeout( - std::time::Duration::from_secs(service_timeout as u64), + std::time::Duration::from_secs((service_timeout*4) as u64), async { while let Some(address_result) = request.next().await { // TODO: Hide server error from clients before release. Currently useful for dev purposes. @@ -1037,27 +1087,212 @@ impl LightWalletIndexer for FetchService { /// more bandwidth-efficient; if two or more transactions in the mempool /// match a shortened txid, they are all sent (none is excluded). Transactions /// in the exclude list that don't exist in the mempool are ignored. - /// - /// NOTE: To be implemented with the mempool updgrade. async fn get_mempool_tx( &self, - _request: Exclude, + request: Exclude, ) -> Result { - Err(FetchServiceError::TonicStatusError(tonic::Status::new( - tonic::Code::Unimplemented, - "get_mempool_tx is not implemented in Zaino.", - ))) + let exclude_txids: Vec = request + .txid + .iter() + .map(|txid_bytes| { + let reversed_txid_bytes: Vec = txid_bytes.iter().cloned().rev().collect(); + hex::encode(&reversed_txid_bytes) + }) + .collect(); + + let mempool = self.mempool.clone(); + let service_timeout = self.config.service_timeout; + let (channel_tx, channel_rx) = + tokio::sync::mpsc::channel(self.config.service_channel_size as usize); + tokio::spawn(async move { + let timeout = timeout( + std::time::Duration::from_secs((service_timeout*4) as u64), + async { + for (txid, transaction) in mempool.get_filtered_mempool(exclude_txids).await { + match transaction.0 { + GetRawTransaction::Object(transaction_object) => { + let txid_bytes = match hex::decode(txid.0) { + Ok(bytes) => bytes, + Err(e) => { + if channel_tx + .send(Err(tonic::Status::unknown(e.to_string()))) + .await + .is_err() + { + break; + } else { + continue; + } + } + }; + match ::parse_from_slice( + transaction_object.hex.as_ref(), + Some(vec!(txid_bytes)), None) + { + Ok(transaction) => { + if !transaction.0.is_empty() { + // TODO: Hide server error from clients before release. Currently useful for dev purposes. + if channel_tx + .send(Err(tonic::Status::unknown("Error: "))) + .await + .is_err() + { + break; + } + } else { + match transaction.1.to_compact(0) { + Ok(compact_tx) => { + if channel_tx + .send(Ok(compact_tx)) + .await + .is_err() + { + break; + } + } + Err(e) => { + // TODO: Hide server error from clients before release. Currently useful for dev purposes. + if channel_tx + .send(Err(tonic::Status::unknown(e.to_string()))) + .await + .is_err() + { + break; + } + } + } + } + } + Err(e) => { + // TODO: Hide server error from clients before release. Currently useful for dev purposes. + if channel_tx + .send(Err(tonic::Status::unknown(e.to_string()))) + .await + .is_err() + { + break; + } + } + } + } + GetRawTransaction::Raw(_) => { + if channel_tx + .send(Err(tonic::Status::internal( + "Error: Received raw transaction type, this should not be impossible.", + ))) + .await + .is_err() + { + break; + } + } + } + } + }, + ) + .await; + match timeout { + Ok(_) => {} + Err(_) => { + channel_tx + .send(Err(tonic::Status::internal( + "Error: get_mempool_tx gRPC request timed out", + ))) + .await + .ok(); + } + } + }); + + Ok(CompactTransactionStream::new(channel_rx)) } /// Return a stream of current Mempool transactions. This will keep the output stream open while /// there are mempool transactions. It will close the returned stream when a new block is mined. - /// - /// NOTE: To be implemented with the mempool updgrade. async fn get_mempool_stream(&self) -> Result { - Err(FetchServiceError::TonicStatusError(tonic::Status::new( - tonic::Code::Unimplemented, - "get_mempool_stream is not implemented in Zaino.", - ))) + let mut mempool = self.mempool.clone(); + let service_timeout = self.config.service_timeout; + let (channel_tx, channel_rx) = + tokio::sync::mpsc::channel(self.config.service_channel_size as usize); + let mempool_height = self.fetcher.get_blockchain_info().await?.blocks.0; + tokio::spawn(async move { + let timeout = timeout( + std::time::Duration::from_secs((service_timeout*6) as u64), + async { + let (mut mempool_stream, _mempool_handle) = + match mempool.get_mempool_stream().await { + Ok(stream) => stream, + Err(e) => { + eprintln!("Error getting mempool stream: {:?}", e); + channel_tx + .send(Err(tonic::Status::internal( + "Error getting mempool stream", + ))) + .await + .ok(); + return; + } + }; + loop { + while let Some(result) = mempool_stream.recv().await { + match result { + Ok((_mempool_key, mempool_value)) => { + match mempool_value.0 { + GetRawTransaction::Object(transaction_object) => { + if channel_tx + .send(Ok(RawTransaction { + data: transaction_object.hex.as_ref().to_vec(), + height: mempool_height as u64, + })) + .await + .is_err() + { + break; + } + } + GetRawTransaction::Raw(_) => { + if channel_tx + .send(Err(tonic::Status::internal( + "Error: Received raw transaction type, this should not be impossible.", + ))) + .await + .is_err() + { + break; + } + } + } + } + Err(e) => { + channel_tx + .send(Err(tonic::Status::internal(format!( + "Error in mempool stream: {:?}", + e + )))) + .await + .ok(); + break; + } + } + } + } + }, + ) + .await; + match timeout { + Ok(_) => {} + Err(_) => { + channel_tx + .send(Err(tonic::Status::internal( + "Error: get_mempool_stream gRPC request timed out", + ))) + .await + .ok(); + } + } + }); + + Ok(RawTransactionStream::new(channel_rx)) } /// GetTreeState returns the note commitment tree state corresponding to the given block. @@ -1194,7 +1429,7 @@ impl LightWalletIndexer for FetchService { tokio::sync::mpsc::channel(self.config.service_channel_size as usize); tokio::spawn(async move { let timeout = timeout( - std::time::Duration::from_secs(service_timeout as u64), + std::time::Duration::from_secs((service_timeout*4) as u64), async { for subtree in subtrees.subtrees { match fetch_service_clone @@ -1384,7 +1619,7 @@ impl LightWalletIndexer for FetchService { tokio::sync::mpsc::channel(self.config.service_channel_size as usize); tokio::spawn(async move { let timeout = timeout( - std::time::Duration::from_secs(service_timeout as u64), + std::time::Duration::from_secs((service_timeout*4) as u64), async { let mut entries: u32 = 0; for utxo in utxos { @@ -1502,7 +1737,7 @@ impl LightWalletIndexer for FetchService { } } -impl FetchService { +impl FetchServiceSubscriber { /// Fetches CompactBlock from the validator. /// /// Uses 2 calls as z_get_block verbosity=1 is required to fetch txids from zcashd. @@ -1657,7 +1892,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); assert_eq!(fetch_service.status(), StatusType::Ready); dbg!(fetch_service.data.clone()); @@ -1706,7 +1942,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let fetch_service_balance = fetch_service .z_get_address_balance(AddressStrings::new_valid(vec![recipient_address]).unwrap()) @@ -1747,7 +1984,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); dbg!(fetch_service .z_get_block("1".to_string(), Some(0)) @@ -1779,7 +2017,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); dbg!(fetch_service .z_get_block("1".to_string(), Some(1)) @@ -1798,16 +2037,41 @@ mod tests { let mut test_manager = TestManager::launch(validator, None, None, true, true) .await .unwrap(); - let clients = test_manager .clients .as_ref() .expect("Clients are not initialized"); + let zebra_uri = format!("http://127.0.0.1:{}", test_manager.zebrad_rpc_listen_port) + .parse::() + .expect("Failed to convert URL to URI"); + + let fetch_service = FetchService::spawn(FetchServiceConfig::new( + SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), + test_manager.zebrad_rpc_listen_port, + ), + None, + None, + None, + None, + Network::new_regtest(Some(1), Some(1)), + )) + .await + .unwrap(); + let fetch_service_subscriber = fetch_service.subscriber(); + + let json_service = JsonRpcConnector::new( + zebra_uri, + Some("xxxxxx".to_string()), + Some("xxxxxx".to_string()), + ) + .await + .unwrap(); test_manager.local_net.generate_blocks(1).await.unwrap(); clients.faucet.do_sync(true).await.unwrap(); - let tx_1 = zingolib::testutils::lightclient::from_inputs::quick_send( + zingolib::testutils::lightclient::from_inputs::quick_send( &clients.faucet, vec![( &clients.get_recipient_address("transparent").await, @@ -1817,7 +2081,7 @@ mod tests { ) .await .unwrap(); - let tx_2 = zingolib::testutils::lightclient::from_inputs::quick_send( + zingolib::testutils::lightclient::from_inputs::quick_send( &clients.faucet, vec![( &clients.get_recipient_address("unified").await, @@ -1828,28 +2092,14 @@ mod tests { .await .unwrap(); - let fetch_service = FetchService::spawn(FetchServiceConfig::new( - SocketAddr::new( - std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), - test_manager.zebrad_rpc_listen_port, - ), - None, - None, - None, - None, - Network::new_regtest(Some(1), Some(1)), - )) - .await - .unwrap(); + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - let fetch_service_mempool = fetch_service.get_raw_mempool().await.unwrap(); + let mut fetch_service_mempool = fetch_service_subscriber.get_raw_mempool().await.unwrap(); + let mut json_service_mempool = json_service.get_raw_mempool().await.unwrap().transactions; - dbg!(&tx_1); - dbg!(&tx_2); dbg!(&fetch_service_mempool); - - assert_eq!(tx_1.first().to_string(), fetch_service_mempool[0]); - assert_eq!(tx_2.first().to_string(), fetch_service_mempool[1]); + dbg!(&json_service_mempool); + assert_eq!(json_service_mempool.sort(), fetch_service_mempool.sort()); test_manager.close().await; } @@ -1895,7 +2145,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); dbg!(fetch_service .z_get_treestate("2".to_string()) @@ -1946,7 +2197,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); dbg!(fetch_service .z_get_subtrees_by_index("orchard".to_string(), NoteCommitmentSubtreeIndex(0), None) @@ -1997,7 +2249,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); dbg!(fetch_service .get_raw_transaction(tx.first().to_string(), Some(1)) @@ -2044,7 +2297,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let fetch_service_txids = fetch_service .get_address_tx_ids(GetAddressTxIdsRequest::from_parts( @@ -2100,7 +2354,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let fetch_service_utxos = fetch_service .z_get_address_utxos(AddressStrings::new_valid(vec![recipient_address]).unwrap()) @@ -2140,7 +2395,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -2188,7 +2444,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -2237,7 +2494,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -2292,7 +2550,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -2362,7 +2621,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -2436,7 +2696,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -2507,7 +2768,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -2578,7 +2840,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -2668,7 +2931,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -2708,6 +2972,246 @@ mod tests { test_manager.close().await; } + #[tokio::test] + async fn fetch_service_get_mempool_tx_zcashd() { + fetch_service_get_mempool_tx("zcashd").await; + } + + async fn fetch_service_get_mempool_tx(validator: &str) { + let mut test_manager = TestManager::launch(validator, None, None, true, true) + .await + .unwrap(); + let zebra_uri = format!("http://127.0.0.1:{}", test_manager.zebrad_rpc_listen_port) + .parse::() + .expect("Failed to convert URL to URI"); + let clients = test_manager + .clients + .as_ref() + .expect("Clients are not initialized"); + + let fetch_service = FetchService::spawn(FetchServiceConfig::new( + SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), + test_manager.zebrad_rpc_listen_port, + ), + None, + None, + None, + None, + Network::new_regtest(Some(1), Some(1)), + )) + .await + .unwrap(); + let fetch_service_subscriber = fetch_service.subscriber(); + + let grpc_service = zaino_serve::rpc::GrpcClient { + zebrad_rpc_uri: zebra_uri, + online: test_manager.online.clone(), + }; + + test_manager.local_net.generate_blocks(1).await.unwrap(); + clients.faucet.do_sync(true).await.unwrap(); + + let tx_1 = zingolib::testutils::lightclient::from_inputs::quick_send( + &clients.faucet, + vec![( + &clients.get_recipient_address("transparent").await, + 250_000, + None, + )], + ) + .await + .unwrap(); + let tx_2 = zingolib::testutils::lightclient::from_inputs::quick_send( + &clients.faucet, + vec![( + &clients.get_recipient_address("unified").await, + 250_000, + None, + )], + ) + .await + .unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + + let exclude_list_empty = Exclude { txid: Vec::new()}; + + let fetch_service_stream = fetch_service_subscriber + .get_mempool_tx(exclude_list_empty.clone()) + .await + .unwrap(); + let fetch_service_mempool_tx: Vec<_> = fetch_service_stream.collect().await; + let grpc_service_stream = grpc_service + .get_mempool_tx(tonic::Request::new(exclude_list_empty)) + .await + .unwrap() + .into_inner(); + let grpc_service_mempool_tx: Vec<_> = grpc_service_stream.collect().await; + + let fetch_mempool_tx: Vec<_> = fetch_service_mempool_tx + .into_iter() + .filter_map(|result| result.ok()) + .collect(); + let grpc_mempool_tx: Vec<_> = grpc_service_mempool_tx + .into_iter() + .filter_map(|result| result.ok()) + .collect(); + + let mut sorted_fetch_mempool_tx = fetch_mempool_tx.clone(); + sorted_fetch_mempool_tx.sort_by_key(|tx| tx.hash.clone()); + let mut sorted_grpc_mempool_tx = grpc_mempool_tx; + sorted_grpc_mempool_tx.sort_by_key(|tx| tx.hash.clone()); + + let mut tx1_bytes = tx_1.first().as_ref().clone(); + tx1_bytes.reverse(); + let mut tx2_bytes = tx_2.first().as_ref().clone(); + tx2_bytes.reverse(); + + let mut sorted_txids = vec![tx1_bytes, tx2_bytes]; + sorted_txids.sort_by_key(|hash| hash.clone()); + + assert_eq!(sorted_fetch_mempool_tx, sorted_grpc_mempool_tx); + assert_eq!(sorted_fetch_mempool_tx[0].hash, sorted_txids[0]); + assert_eq!(sorted_fetch_mempool_tx[1].hash, sorted_txids[1]); + + let exclude_list = Exclude { txid: vec![sorted_txids[0][..8].to_vec()]}; + + let exclude_fetch_service_stream = fetch_service_subscriber + .get_mempool_tx(exclude_list.clone()) + .await + .unwrap(); + let exclude_fetch_service_mempool_tx: Vec<_> = exclude_fetch_service_stream.collect().await; + let exclude_grpc_service_stream = grpc_service + .get_mempool_tx(tonic::Request::new(exclude_list)) + .await + .unwrap() + .into_inner(); + let exclude_grpc_service_mempool_tx: Vec<_> = exclude_grpc_service_stream.collect().await; + + let exclude_fetch_mempool_tx: Vec<_> = exclude_fetch_service_mempool_tx + .into_iter() + .filter_map(|result| result.ok()) + .collect(); + let exclude_grpc_mempool_tx: Vec<_> = exclude_grpc_service_mempool_tx + .into_iter() + .filter_map(|result| result.ok()) + .collect(); + + let mut sorted_exclude_fetch_mempool_tx = exclude_fetch_mempool_tx.clone(); + sorted_exclude_fetch_mempool_tx.sort_by_key(|tx| tx.hash.clone()); + let mut sorted_exclude_grpc_mempool_tx = exclude_grpc_mempool_tx; + sorted_exclude_grpc_mempool_tx.sort_by_key(|tx| tx.hash.clone()); + + assert_eq!(sorted_exclude_fetch_mempool_tx, sorted_exclude_grpc_mempool_tx); + assert_eq!(sorted_exclude_fetch_mempool_tx[0].hash, sorted_txids[1]); + + test_manager.close().await; + } + + #[tokio::test] + async fn fetch_service_get_mempool_stream_zcashd() { + fetch_service_get_mempool_stream("zcashd").await; + } + + async fn fetch_service_get_mempool_stream(validator: &str) { + let mut test_manager = TestManager::launch(validator, None, None, true, true) + .await + .unwrap(); + let zebra_uri = format!("http://127.0.0.1:{}", test_manager.zebrad_rpc_listen_port) + .parse::() + .expect("Failed to convert URL to URI"); + let clients = test_manager + .clients + .as_ref() + .expect("Clients are not initialized"); + + let fetch_service = FetchService::spawn(FetchServiceConfig::new( + SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), + test_manager.zebrad_rpc_listen_port, + ), + None, + None, + None, + None, + Network::new_regtest(Some(1), Some(1)), + )) + .await + .unwrap(); + let fetch_service_subscriber = fetch_service.subscriber(); + let grpc_service = zaino_serve::rpc::GrpcClient { + zebrad_rpc_uri: zebra_uri, + online: test_manager.online.clone(), + }; + + test_manager.local_net.generate_blocks(1).await.unwrap(); + clients.faucet.do_sync(true).await.unwrap(); + + let fetch_service_handle = tokio::spawn(async move { + let fetch_service_stream = fetch_service_subscriber + .get_mempool_stream() + .await + .unwrap(); + let fetch_service_mempool_tx: Vec<_> = fetch_service_stream.collect().await; + fetch_service_mempool_tx + .into_iter() + .filter_map(|result| result.ok()) + .collect::>() + }); + let grpc_service_handle = tokio::spawn(async move { + let grpc_service_stream = grpc_service.get_mempool_stream( + tonic::Request::new( + zaino_proto::proto::service::Empty {}, + )).await + .unwrap() + .into_inner(); + let grpc_service_mempool_tx: Vec<_> = grpc_service_stream.collect().await; + grpc_service_mempool_tx + .into_iter() + .filter_map(|result| result.ok()) + .collect::>() + }); + + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + + zingolib::testutils::lightclient::from_inputs::quick_send( + &clients.faucet, + vec![( + &clients.get_recipient_address("transparent").await, + 250_000, + None, + )], + ) + .await + .unwrap(); + zingolib::testutils::lightclient::from_inputs::quick_send( + &clients.faucet, + vec![( + &clients.get_recipient_address("unified").await, + 250_000, + None, + )], + ) + .await + .unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + test_manager.local_net.generate_blocks(1).await.unwrap(); + + let fetch_mempool_tx = fetch_service_handle.await.unwrap(); + let grpc_mempool_tx = grpc_service_handle.await.unwrap(); + + let mut sorted_fetch_mempool_tx = fetch_mempool_tx.clone(); + sorted_fetch_mempool_tx.sort_by_key(|tx| tx.data.clone()); + let mut sorted_grpc_mempool_tx = grpc_mempool_tx; + sorted_grpc_mempool_tx.sort_by_key(|tx| tx.data.clone()); + + assert_eq!(sorted_fetch_mempool_tx, sorted_grpc_mempool_tx); + + test_manager.close().await; + } + #[tokio::test] async fn fetch_service_get_tree_state_zcashd() { fetch_service_get_tree_state("zcashd").await; @@ -2733,7 +3237,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -2784,7 +3289,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -2828,7 +3334,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -2897,7 +3404,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -2971,7 +3479,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -3043,7 +3552,8 @@ mod tests { Network::new_regtest(Some(1), Some(1)), )) .await - .unwrap(); + .unwrap() + .subscriber(); let grpc_service = zaino_serve::rpc::GrpcClient { zebrad_rpc_uri: zebra_uri, online: test_manager.online.clone(), @@ -3056,11 +3566,14 @@ mod tests { .unwrap()) .into_inner(); - // Clean build date from responses. + // Clean build date and git commit from responses. let mut fetch_service_cleaned_info = fetch_service_get_lightd_info.clone(); let mut grpc_service_cleaned_info = grpc_service_get_lightd_info.clone(); fetch_service_cleaned_info.build_date = String::new(); grpc_service_cleaned_info.build_date = String::new(); + fetch_service_cleaned_info.git_commit = String::new(); + grpc_service_cleaned_info.git_commit = String::new(); + assert_eq!(fetch_service_cleaned_info, grpc_service_cleaned_info); diff --git a/zaino-state/src/lib.rs b/zaino-state/src/lib.rs index e08c329..88363ee 100644 --- a/zaino-state/src/lib.rs +++ b/zaino-state/src/lib.rs @@ -5,10 +5,12 @@ use zebra_chain::parameters::Network; +pub mod broadcast; pub mod config; pub mod error; pub mod fetch; pub mod indexer; +pub mod mempool; pub mod state; pub mod status; pub mod stream; diff --git a/zaino-state/src/mempool.rs b/zaino-state/src/mempool.rs new file mode 100644 index 0000000..1341b0d --- /dev/null +++ b/zaino-state/src/mempool.rs @@ -0,0 +1,400 @@ +//! Holds Zaino's mempool implementation. + +use std::collections::HashSet; + +use crate::{ + broadcast::{Broadcast, BroadcastSubscriber}, + error::{MempoolError, StatusError}, + status::{AtomicStatus, StatusType}, +}; +use zaino_fetch::jsonrpc::connector::JsonRpcConnector; +use zebra_chain::block::Hash; +use zebra_rpc::methods::GetRawTransaction; + +/// Mempool key +/// +/// Holds txid. +#[derive(Debug, Clone, PartialEq, Hash, Eq)] +pub struct MempoolKey(pub String); + +/// Mempool value. +/// +/// NOTE: Currently holds a copy of txid, +/// this could be updated to store the corresponding transaction as the value, +/// this would enable the serving of mempool trasactions directly, significantly increasing efficiency. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MempoolValue(pub GetRawTransaction); + +/// Zcash mempool, uses dashmap for efficient serving of mempool tx. +#[derive(Debug)] +pub struct Mempool { + /// Zcash chain fetch service. + fetcher: JsonRpcConnector, + /// Wrapper for a dashmap of mempool transactions. + state: Broadcast, + /// Mempool sync handle. + sync_task_handle: Option>, + /// mempool status. + status: AtomicStatus, +} + +impl Mempool { + /// Spawns a new [`Mempool`]. + pub async fn spawn( + fetcher: &JsonRpcConnector, + capacity_and_shard_amount: Option<(usize, usize)>, + ) -> Result { + let mut mempool = Mempool { + fetcher: fetcher.clone(), + state: match capacity_and_shard_amount { + Some((capacity, shard_amount)) => Broadcast::new_custom(capacity, shard_amount), + None => Broadcast::new_default(), + }, + sync_task_handle: None, + status: AtomicStatus::new(StatusType::Spawning.into()), + }; + + loop { + match mempool.get_mempool_transactions().await { + Ok(mempool_transactions) => { + mempool.status.store(StatusType::Ready.into()); + mempool + .state + .insert_filtered_set(mempool_transactions, mempool.status.clone().into()); + break; + } + Err(e) => { + mempool.status.store(StatusType::Spawning.into()); + mempool.state.notify(mempool.status.clone().into()); + eprintln!("{e}"); + continue; + } + }; + } + + mempool.sync_task_handle = Some(mempool.serve().await?); + + Ok(mempool) + } + + async fn serve(&self) -> Result, MempoolError> { + let mempool = self.clone(); + let state = self.state.clone(); + let status = self.status.clone(); + status.store(StatusType::Spawning.into()); + + let sync_handle = tokio::spawn(async move { + let mut best_block_hash: Hash; + let mut check_block_hash: Hash; + + loop { + match mempool.fetcher.get_blockchain_info().await { + Ok(chain_info) => { + best_block_hash = chain_info.best_block_hash.clone(); + break; + } + Err(e) => { + state.notify(status.clone().into()); + eprintln!("{e}"); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + continue; + } + } + } + + loop { + match mempool.fetcher.get_blockchain_info().await { + Ok(chain_info) => { + check_block_hash = chain_info.best_block_hash.clone(); + } + Err(e) => { + status.store(StatusType::RecoverableError.into()); + state.notify(status.clone().into()); + eprintln!("{e}"); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + continue; + } + } + + if check_block_hash != best_block_hash { + best_block_hash = check_block_hash; + status.store(StatusType::Syncing.into()); + state.notify(status.clone().into()); + state.clear(); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + continue; + } + + match mempool.get_mempool_transactions().await { + Ok(mempool_transactions) => { + status.store(StatusType::Ready.into()); + state.insert_filtered_set(mempool_transactions, status.clone().into()); + } + Err(e) => { + status.store(StatusType::RecoverableError.into()); + state.notify(status.clone().into()); + eprintln!("{e}"); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + continue; + } + }; + + if status.load() == StatusType::Closing as usize { + state.notify(status.into()); + return; + } + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + }); + + Ok(sync_handle) + } + + /// Returns all transactions in the mempool. + async fn get_mempool_transactions( + &self, + ) -> Result, MempoolError> { + let mut transactions = Vec::new(); + + for txid in self.fetcher.get_raw_mempool().await?.transactions { + let transaction = self + .fetcher + .get_raw_transaction(txid.clone(), Some(1)) + .await?; + //process txid + transactions.push((MempoolKey(txid), MempoolValue(transaction.into()))); + } + + Ok(transactions) + } + + /// Returns a [`MempoolSubscriber`]. + pub fn subscriber(&self) -> MempoolSubscriber { + MempoolSubscriber { + subscriber: self.state.subscriber(), + seen_txids: HashSet::new(), + status: self.status.clone(), + } + } + + /// Returns the status of the mempool. + pub fn status(&self) -> StatusType { + self.status.load().into() + } + + /// Sets the mempool to close gracefully. + pub fn close(&mut self) { + self.status.store(StatusType::Closing.into()); + self.state.notify(self.status()); + if let Some(handle) = self.sync_task_handle.take() { + handle.abort(); + } + } +} + +impl Drop for Mempool { + fn drop(&mut self) { + self.status.store(StatusType::Closing.into()); + self.state.notify(StatusType::Closing); + if let Some(handle) = self.sync_task_handle.take() { + handle.abort(); + } + } +} + +impl Clone for Mempool { + fn clone(&self) -> Self { + Self { + fetcher: self.fetcher.clone(), + state: self.state.clone(), + sync_task_handle: None, + status: self.status.clone(), + } + } +} + +/// A subscriber to a [`Mempool`]. +#[derive(Debug, Clone)] +pub struct MempoolSubscriber { + subscriber: BroadcastSubscriber, + seen_txids: HashSet, + status: AtomicStatus, +} + +impl MempoolSubscriber { + /// Returns all tx currently in the mempool. + pub async fn get_mempool(&self) -> Vec<(MempoolKey, MempoolValue)> { + self.subscriber.get_filtered_state(&HashSet::new()) + } + + /// Returns all tx currently in the mempool filtered by [`exclude_list`]. + /// + /// The transaction IDs in the Exclude list can be shortened to any number of bytes to make the request + /// more bandwidth-efficient; if two or more transactions in the mempool + /// match a shortened txid, they are all sent (none is excluded). Transactions + /// in the exclude list that don't exist in the mempool are ignored. + pub async fn get_filtered_mempool( + &self, + exclude_list: Vec, + ) -> Vec<(MempoolKey, MempoolValue)> { + let mempool_tx = self.subscriber.get_filtered_state(&HashSet::new()); + + let mempool_txids: HashSet = mempool_tx + .iter() + .map(|(mempool_key, _)| mempool_key.0.clone()) + .collect(); + + let mut txids_to_exclude: HashSet = HashSet::new(); + for exclude_txid in &exclude_list { + // Convert to big endian (server format). + let server_exclude_txid: String = exclude_txid + .chars() + .collect::>() + .chunks(2) + .rev() + .map(|chunk| chunk.iter().collect::()) + .collect(); + let matching_txids: Vec<&String> = mempool_txids + .iter() + .filter(|txid| txid.starts_with(&server_exclude_txid)) + .collect(); + + if matching_txids.len() == 1 { + txids_to_exclude.insert(MempoolKey(matching_txids[0].clone())); + } + } + + mempool_tx + .into_iter() + .filter(|(mempool_key, _)| !txids_to_exclude.contains(mempool_key)) + .collect() + } + + /// Returns a stream of mempool txids, closes the channel when a new block has been mined. + pub async fn get_mempool_stream( + &mut self, + ) -> Result< + ( + tokio::sync::mpsc::Receiver>, + tokio::task::JoinHandle<()>, + ), + MempoolError, + > { + let mut subscriber = self.clone(); + subscriber.seen_txids.clear(); + let (channel_tx, channel_rx) = tokio::sync::mpsc::channel(32); + + let streamer_handle = tokio::spawn(async move { + let mempool_result: Result<(), MempoolError> = async { + loop { + let (mempool_status, mempool_updates) = subscriber.wait_on_update().await?; + match mempool_status { + StatusType::Ready => { + for (mempool_key, mempool_value) in mempool_updates { + loop { + match channel_tx + .try_send(Ok((mempool_key.clone(), mempool_value.clone()))) + { + Ok(_) => break, + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + tokio::time::sleep(std::time::Duration::from_millis( + 100, + )) + .await; + continue; + } + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { + return Ok(()); + } + } + } + } + } + StatusType::Syncing => { + return Ok(()); + } + StatusType::Closing => { + return Err(MempoolError::StatusError(StatusError( + StatusType::Closing, + ))); + } + StatusType::RecoverableError => { + continue; + } + status => { + return Err(MempoolError::StatusError(StatusError(status))); + } + } + if subscriber.status.load() == StatusType::Closing as usize { + return Err(MempoolError::StatusError(StatusError(StatusType::Closing))); + } + } + } + .await; + + if let Err(mempool_error) = mempool_result { + eprintln!("Error in mempool stream: {:?}", mempool_error); + match mempool_error { + MempoolError::StatusError(error_status) => { + let _ = channel_tx.send(Err(error_status)).await; + } + _ => { + let _ = channel_tx + .send(Err(StatusError(StatusType::RecoverableError))) + .await; + } + } + } + }); + + Ok((channel_rx, streamer_handle)) + } + + /// Returns the status of the mempool. + pub fn status(&self) -> StatusType { + self.status.load().into() + } + + /// Returns all tx currently in the mempool and updates seen_txids. + fn get_mempool_and_update_seen(&mut self) -> Vec<(MempoolKey, MempoolValue)> { + let mempool_updates = self.subscriber.get_filtered_state(&HashSet::new()); + for (mempool_key, _) in mempool_updates.clone() { + self.seen_txids.insert(mempool_key); + } + mempool_updates + } + + /// Returns txids not yet seen by the subscriber and updates seen_txids. + fn get_mempool_updates_and_update_seen(&mut self) -> Vec<(MempoolKey, MempoolValue)> { + let mempool_updates = self.subscriber.get_filtered_state(&self.seen_txids); + for (mempool_key, _) in mempool_updates.clone() { + self.seen_txids.insert(mempool_key); + } + mempool_updates + } + + /// Waits on update from mempool and updates the mempool, returning either the new mempool or the mempool updates, along with the mempool status. + async fn wait_on_update( + &mut self, + ) -> Result<(StatusType, Vec<(MempoolKey, MempoolValue)>), MempoolError> { + let update_status = self.subscriber.wait_on_notifier().await?; + match update_status { + StatusType::Ready => Ok(( + StatusType::Ready, + self.get_mempool_updates_and_update_seen(), + )), + StatusType::Syncing => { + self.clear_seen(); + Ok((StatusType::Syncing, self.get_mempool_and_update_seen())) + } + StatusType::Closing => Ok((StatusType::Closing, Vec::new())), + status => return Err(MempoolError::StatusError(StatusError(status))), + } + } + + /// Clears the subscribers seen_txids. + fn clear_seen(&mut self) { + self.seen_txids.clear(); + } +} diff --git a/zaino-state/src/state.rs b/zaino-state/src/state.rs index c7ab31d..828b86f 100644 --- a/zaino-state/src/state.rs +++ b/zaino-state/src/state.rs @@ -53,7 +53,7 @@ pub struct StateService { /// Monitors changes in the chain tip. _chain_tip_change: ChainTipChange, /// Sync task handle. - sync_task_handle: tokio::task::JoinHandle<()>, + sync_task_handle: Option>, /// JsonRPC Client. _rpc_client: JsonRpcConnector, /// Service metadata. @@ -102,7 +102,7 @@ impl StateService { read_state_service, latest_chain_tip, _chain_tip_change: chain_tip_change, - sync_task_handle, + sync_task_handle: Some(sync_task_handle), _rpc_client: rpc_client, data, config, @@ -178,13 +178,34 @@ impl StateService { /// Shuts down the StateService. pub fn close(&mut self) { - self.sync_task_handle.abort(); + if self.sync_task_handle.is_some() { + if let Some(handle) = self.sync_task_handle.take() { + handle.abort(); + } + } } } impl Drop for StateService { fn drop(&mut self) { - self.close() + if let Some(handle) = self.sync_task_handle.take() { + handle.abort(); + } + } +} + +impl Clone for StateService { + fn clone(&self) -> Self { + Self { + read_state_service: self.read_state_service.clone(), + latest_chain_tip: self.latest_chain_tip.clone(), + _chain_tip_change: self._chain_tip_change.clone(), + sync_task_handle: None, + _rpc_client: self._rpc_client.clone(), + data: self.data.clone(), + config: self.config.clone(), + status: self.status.clone(), + } } }