diff --git a/crates/autopilot/src/arguments.rs b/crates/autopilot/src/arguments.rs index 677237ea84..d62d5c0156 100644 --- a/crates/autopilot/src/arguments.rs +++ b/crates/autopilot/src/arguments.rs @@ -4,7 +4,7 @@ use { clap::ValueEnum, primitive_types::H160, shared::{ - arguments::{display_list, display_option, ExternalSolver}, + arguments::{display_list, display_option, Db, ExternalSolver}, bad_token::token_owner_finder, http_client, price_estimation::{self, NativePriceEstimators}, @@ -54,6 +54,10 @@ pub struct Arguments { #[clap(long, env, default_value = "postgresql://")] pub db_url: Url, + /// Url of the Postgres database. + #[clap(flatten)] + pub db: Db, + /// The number of order events to insert in a single batch. #[clap(long, env, default_value = "500")] pub insert_batch_size: NonZeroUsize, @@ -269,6 +273,7 @@ impl std::fmt::Display for Arguments { order_events_cleanup_interval, order_events_cleanup_threshold, db_url, + db, insert_batch_size, native_price_estimation_results_required, max_settlement_transaction_wait, @@ -292,6 +297,8 @@ impl std::fmt::Display for Arguments { writeln!(f, "metrics_address: {}", metrics_address)?; let _intentionally_ignored = db_url; writeln!(f, "db_url: SECRET")?; + let _intentionally_ignored = db; + writeln!(f, "db: SECRET")?; writeln!(f, "skip_event_sync: {}", skip_event_sync)?; writeln!(f, "allowed_tokens: {:?}", allowed_tokens)?; writeln!(f, "unsupported_tokens: {:?}", unsupported_tokens)?; diff --git a/crates/autopilot/src/run.rs b/crates/autopilot/src/run.rs index df4f8371c3..810495bd93 100644 --- a/crates/autopilot/src/run.rs +++ b/crates/autopilot/src/run.rs @@ -127,7 +127,9 @@ pub async fn start(args: impl Iterator) { pub async fn run(args: Arguments) { assert!(args.shadow.is_none(), "cannot run in shadow mode"); - let db = Postgres::new(args.db_url.as_str(), args.insert_batch_size) + let db_url = args.db.to_url().unwrap_or(args.db_url); + + let db = Postgres::new(db_url.as_str(), args.insert_batch_size) .await .unwrap(); crate::database::run_database_metrics_work(db.clone()); diff --git a/crates/orderbook/src/arguments.rs b/crates/orderbook/src/arguments.rs index 2f214c883c..90821052d3 100644 --- a/crates/orderbook/src/arguments.rs +++ b/crates/orderbook/src/arguments.rs @@ -2,7 +2,7 @@ use { primitive_types::H160, reqwest::Url, shared::{ - arguments::{display_option, display_secret_option}, + arguments::{display_option, display_secret_option, Db}, bad_token::token_owner_finder, http_client, price_estimation::{self, NativePriceEstimators}, @@ -40,6 +40,10 @@ pub struct Arguments { #[clap(long, env, default_value = "postgresql://")] pub db_url: Url, + /// Url of the Postgres database. + #[clap(flatten)] + pub db: Db, + /// The minimum amount of time in seconds an order has to be valid for. #[clap( long, @@ -166,6 +170,7 @@ impl std::fmt::Display for Arguments { hooks_contract_address, app_data_size_limit, db_url, + db, max_gas_per_order, } = self; @@ -178,6 +183,8 @@ impl std::fmt::Display for Arguments { writeln!(f, "bind_address: {}", bind_address)?; let _intentionally_ignored = db_url; writeln!(f, "db_url: SECRET")?; + let _intentionally_ignored = db; + writeln!(f, "db: SECRET")?; writeln!( f, "min_order_validity_period: {:?}", diff --git a/crates/shared/src/arguments.rs b/crates/shared/src/arguments.rs index 70d218e108..e8603260ed 100644 --- a/crates/shared/src/arguments.rs +++ b/crates/shared/src/arguments.rs @@ -259,6 +259,36 @@ pub struct Arguments { pub token_quality_cache_prefetch_time: Duration, } +#[derive(Clone, clap::Parser)] +pub struct Db { + /// Base Url of the Postgres database. By default connects to locally + /// running postgres. + #[clap(long, env, default_value = "postgresql://")] + pub db_base_url: Option, + /// Database Username + #[clap(long, env)] + pub db_user: Option, + /// Database password for the given username + #[clap(long, env)] + pub db_password: Option, +} + +impl Db { + /// Returns the DB URL with credentials + /// Returns `None` if the URL is not configured + pub fn to_url(&self) -> Option { + let mut url = self.db_base_url.clone()?; + + if let Some(user) = &self.db_user { + url.query_pairs_mut() + .append_pair("user", user) + .extend_pairs(self.db_password.as_ref().map(|pass| ("password", pass))); + } + + Some(url) + } +} + pub fn display_secret_option( f: &mut Formatter<'_>, name: &str, @@ -282,7 +312,7 @@ pub fn display_option( pub fn display_list( f: &mut Formatter<'_>, name: &str, - iter: impl IntoIterator, + iter: impl IntoIterator, ) -> std::fmt::Result where T: Display, @@ -505,6 +535,59 @@ mod test { assert_eq!(driver, expected); } + #[test] + fn db_url_just_base_url() { + let db = Db { + db_base_url: Some("postgresql://mydatabase:1234".try_into().unwrap()), + db_user: None, + db_password: None, + }; + + assert_eq!( + db.to_url(), + Url::try_from("postgresql://mydatabase:1234").ok() + ); + } + + #[test] + fn db_url_base_url_with_user() { + let db = Db { + db_base_url: Some("postgresql://mydatabase:1234".try_into().unwrap()), + db_user: Some("myuser".to_string()), + db_password: None, + }; + + assert_eq!( + db.to_url(), + Url::try_from("postgresql://mydatabase:1234?user=myuser").ok() + ); + } + + #[test] + fn db_url_base_url_with_user_and_password() { + let db = Db { + db_base_url: Some("postgresql://mydatabase:1234".try_into().unwrap()), + db_user: Some("myuser".to_string()), + db_password: Some("mypassword".to_string()), + }; + + assert_eq!( + db.to_url(), + Url::try_from("postgresql://mydatabase:1234?user=myuser&password=mypassword").ok() + ); + } + + #[test] + fn db_url_empty() { + let db = Db { + db_base_url: None, + db_user: None, + db_password: None, + }; + + assert_eq!(db.to_url(), None); + } + #[test] fn parse_driver_with_threshold() { let argument = "name1|http://localhost:8080|1000000000000000000";