diff --git a/crates/shared/src/price_estimation/native_price_cache.rs b/crates/shared/src/price_estimation/native_price_cache.rs index 496f9340fd..04c45a5d30 100644 --- a/crates/shared/src/price_estimation/native_price_cache.rs +++ b/crates/shared/src/price_estimation/native_price_cache.rs @@ -70,6 +70,41 @@ struct CachedResult { result: CacheEntry, updated_at: Instant, requested_at: Instant, + accumulative_errors_count: u32, +} + +/// Defines how many consecutive errors are allowed before the cache starts +/// returning the error to the user without trying to fetch the price from the +/// estimator. +const ACCUMULATIVE_ERRORS_THRESHOLD: u32 = 5; + +impl CachedResult { + fn new( + result: CacheEntry, + updated_at: Instant, + requested_at: Instant, + current_accumulative_errors_count: u32, + ) -> Self { + let estimator_internal_errors_count = + matches!(result, Err(PriceEstimationError::EstimatorInternal(_))) + .then_some(current_accumulative_errors_count + 1) + .unwrap_or_default(); + + Self { + result, + updated_at, + requested_at, + accumulative_errors_count: estimator_internal_errors_count, + } + } + + /// The result is not ready if the estimator has returned an internal error + /// and consecutive errors are less than + /// `ESTIMATOR_INTERNAL_ERRORS_THRESHOLD`. + fn is_ready(&self) -> bool { + !matches!(self.result, Err(PriceEstimationError::EstimatorInternal(_))) + || self.accumulative_errors_count >= ACCUMULATIVE_ERRORS_THRESHOLD + } } impl Inner { @@ -80,13 +115,13 @@ impl Inner { cache: &mut MutexGuard>, max_age: &Duration, create_missing_entry: bool, - ) -> Option { + ) -> Option { match cache.entry(token) { Entry::Occupied(mut entry) => { let entry = entry.get_mut(); entry.requested_at = now; let is_recent = now.saturating_duration_since(entry.updated_at) < *max_age; - is_recent.then_some(entry.result.clone()) + is_recent.then_some(entry.clone()) } Entry::Vacant(entry) => { if create_missing_entry { @@ -95,19 +130,31 @@ impl Inner { // This should happen only for prices missing while building the auction. // Otherwise malicious actors could easily cause the cache size to blow up. let outdated_timestamp = now.checked_sub(*max_age).unwrap(); - entry.insert(CachedResult { - result: Ok(0.), - updated_at: outdated_timestamp, - requested_at: now, - }); + entry.insert(CachedResult::new( + Ok(0.), + outdated_timestamp, + now, + Default::default(), + )); } None } } } + fn get_ready_to_use_cached_price( + token: H160, + now: Instant, + cache: &mut MutexGuard>, + max_age: &Duration, + create_missing_entry: bool, + ) -> Option { + Self::get_cached_price(token, now, cache, max_age, create_missing_entry) + .filter(|cached| cached.is_ready()) + } + /// Checks cache for the given tokens one by one. If the price is already - /// cached it gets returned. If it's not in the cache a new price + /// cached, it gets returned. If it's not in the cache, a new price /// estimation request gets issued. We check the cache before each /// request because they can take a long time and some other task might /// have fetched some requested price in the meantime. @@ -117,15 +164,19 @@ impl Inner { max_age: Duration, ) -> futures::stream::BoxStream<'a, (H160, NativePriceEstimateResult)> { let estimates = tokens.iter().map(move |token| async move { - { - // check if price is cached by now + let current_accumulative_errors_count = { + // check if the price is cached by now let now = Instant::now(); let mut cache = self.cache.lock().unwrap(); - let price = Self::get_cached_price(*token, now, &mut cache, &max_age, false); - if let Some(price) = price { - return (*token, price); + + match Self::get_cached_price(*token, now, &mut cache, &max_age, false) { + Some(cached) if cached.is_ready() => { + return (*token, cached.result); + } + Some(cached) => cached.accumulative_errors_count, + None => Default::default(), } - } + }; let result = self.estimator.estimate_native_price(*token).await; @@ -133,13 +184,10 @@ impl Inner { if should_cache(&result) { let now = Instant::now(); let mut cache = self.cache.lock().unwrap(); + cache.insert( *token, - CachedResult { - result: result.clone(), - updated_at: now, - requested_at: now, - }, + CachedResult::new(result.clone(), now, now, current_accumulative_errors_count), ); }; @@ -178,10 +226,11 @@ fn should_cache(result: &Result) -> bool { match result { Ok(_) | Err(PriceEstimationError::NoLiquidity { .. }) - | Err(PriceEstimationError::UnsupportedToken { .. }) => true, - Err(PriceEstimationError::EstimatorInternal(_)) - | Err(PriceEstimationError::ProtocolInternal(_)) - | Err(PriceEstimationError::RateLimited) => false, + | Err(PriceEstimationError::UnsupportedToken { .. }) + | Err(PriceEstimationError::EstimatorInternal(_)) => true, + Err(PriceEstimationError::ProtocolInternal(_)) | Err(PriceEstimationError::RateLimited) => { + false + } Err(PriceEstimationError::UnsupportedOrderType(_)) => { tracing::error!(?result, "Unexpected error in native price cache"); false @@ -241,11 +290,12 @@ impl CachingNativePriceEstimator { Some(( token, - CachedResult { - result: Ok(from_normalized_price(price)?), + CachedResult::new( + Ok(from_normalized_price(price)?), updated_at, - requested_at: now, - }, + now, + Default::default(), + ), )) }) .collect::>(); @@ -300,14 +350,20 @@ impl CachingNativePriceEstimator { let mut cache = self.0.cache.lock().unwrap(); let mut results = HashMap::default(); for token in tokens { - let cached = Inner::get_cached_price(*token, now, &mut cache, &self.0.max_age, true); + let cached = Inner::get_ready_to_use_cached_price( + *token, + now, + &mut cache, + &self.0.max_age, + true, + ); let label = if cached.is_some() { "hits" } else { "misses" }; Metrics::get() .native_price_cache_access .with_label_values(&[label]) .inc_by(1); if let Some(result) = cached { - results.insert(*token, result); + results.insert(*token, result.result); } } results @@ -359,7 +415,7 @@ impl NativePriceEstimating for CachingNativePriceEstimator { let cached = { let now = Instant::now(); let mut cache = self.0.cache.lock().unwrap(); - Inner::get_cached_price(token, now, &mut cache, &self.0.max_age, false) + Inner::get_ready_to_use_cached_price(token, now, &mut cache, &self.0.max_age, false) }; let label = if cached.is_some() { "hits" } else { "misses" }; @@ -368,8 +424,8 @@ impl NativePriceEstimating for CachingNativePriceEstimator { .with_label_values(&[label]) .inc_by(1); - if let Some(price) = cached { - return price; + if let Some(cached) = cached { + return cached.result; } self.0 @@ -391,6 +447,7 @@ mod tests { native::{MockNativePriceEstimating, NativePriceEstimating}, PriceEstimationError, }, + anyhow::anyhow, futures::FutureExt, num::ToPrimitive, }; @@ -485,6 +542,107 @@ mod tests { } } + #[tokio::test] + async fn properly_caches_accumulative_errors() { + let mut inner = MockNativePriceEstimating::new(); + let mut seq = mockall::Sequence::new(); + + // First 3 calls: Return EstimatorInternal error. Increment the errors counter. + inner + .expect_estimate_native_price() + .times(3) + .in_sequence(&mut seq) + .returning(|_| { + async { Err(PriceEstimationError::EstimatorInternal(anyhow!("boom"))) }.boxed() + }); + + // Next 1 call: Return Ok(1.0). This resets the errors counter. + inner + .expect_estimate_native_price() + .once() + .in_sequence(&mut seq) + .returning(|_| async { Ok(1.0) }.boxed()); + + // Next 2 calls: Return EstimatorInternal error. Start incrementing the errors + // counter from the beginning. + inner + .expect_estimate_native_price() + .times(2) + .in_sequence(&mut seq) + .returning(|_| { + async { Err(PriceEstimationError::EstimatorInternal(anyhow!("boom"))) }.boxed() + }); + + // Next call: Return a recoverable error, which doesn't affect the errors + // counter. + inner + .expect_estimate_native_price() + .once() + .in_sequence(&mut seq) + .returning(|_| async { Err(PriceEstimationError::RateLimited) }.boxed()); + + // Since the ACCUMULATIVE_ERRORS_THRESHOLD is 5, there are only 3 more calls + // remain. Anything exceeding that must return the cached value. + inner + .expect_estimate_native_price() + .times(3) + .in_sequence(&mut seq) + .returning(|_| { + async { Err(PriceEstimationError::EstimatorInternal(anyhow!("boom"))) }.boxed() + }); + + let estimator = CachingNativePriceEstimator::new( + Box::new(inner), + Duration::from_millis(100), + Duration::from_millis(200), + None, + Default::default(), + 1, + ); + + // First 3 calls: The cache is not used. Counter gets increased. + for _ in 0..3 { + let result = estimator.estimate_native_price(token(0)).await; + assert!(matches!( + result.as_ref().unwrap_err(), + PriceEstimationError::EstimatorInternal(_) + )); + } + + // Reset the errors counter. + let result = estimator.estimate_native_price(token(0)).await; + assert_eq!(result.as_ref().unwrap().to_i64().unwrap(), 1); + + // Make sure the cached value gets evicted. + tokio::time::sleep(Duration::from_millis(120)).await; + + // Increment the errors counter again. + for _ in 0..2 { + let result = estimator.estimate_native_price(token(0)).await; + assert!(matches!( + result.as_ref().unwrap_err(), + PriceEstimationError::EstimatorInternal(_) + )); + } + + // Receive a recoverable error, which shouldn't affect the counter. + let result = estimator.estimate_native_price(token(0)).await; + assert!(matches!( + result.as_ref().unwrap_err(), + PriceEstimationError::RateLimited + )); + + // Make more than expected calls. The cache should be used once the threshold is + // reached. + for _ in 0..(ACCUMULATIVE_ERRORS_THRESHOLD * 2) { + let result = estimator.estimate_native_price(token(0)).await; + assert!(matches!( + result.as_ref().unwrap_err(), + PriceEstimationError::EstimatorInternal(_) + )); + } + } + #[tokio::test] async fn does_not_cache_recoverable_failed_estimates() { let mut inner = MockNativePriceEstimating::new(); @@ -665,22 +823,8 @@ mod tests { let inner = Inner { cache: Mutex::new( [ - ( - t0, - CachedResult { - result: Ok(0.), - updated_at: now, - requested_at: now, - }, - ), - ( - t1, - CachedResult { - result: Ok(0.), - updated_at: now, - requested_at: now, - }, - ), + (t0, CachedResult::new(Ok(0.), now, now, Default::default())), + (t1, CachedResult::new(Ok(0.), now, now, Default::default())), ] .into_iter() .collect(),