diff --git a/Cargo.lock b/Cargo.lock index 573e31eababd..4c59bfea573b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3571,6 +3571,7 @@ dependencies = [ "connection-string", "either", "futures", + "getrandom 0.2.10", "hex", "indoc 0.3.6", "lru-cache", diff --git a/Cargo.toml b/Cargo.toml index 4a3cd1450caf..b32a1a85cf18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,7 @@ features = [ "pooled", "postgresql", "sqlite", + "native", ] [profile.dev.package.backtrace] diff --git a/libs/user-facing-errors/Cargo.toml b/libs/user-facing-errors/Cargo.toml index 9900892209c6..3049a19712b1 100644 --- a/libs/user-facing-errors/Cargo.toml +++ b/libs/user-facing-errors/Cargo.toml @@ -11,7 +11,7 @@ backtrace = "0.3.40" tracing = "0.1" indoc.workspace = true itertools = "0.10" -quaint = { workspace = true, optional = true } +quaint = { path = "../../quaint", optional = true } [features] default = [] diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index b699518d0910..52a7edf72aca 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -23,20 +23,28 @@ resolver = "2" features = ["docs", "all"] [features] -default = [] +default = ["mysql", "postgresql", "mssql", "sqlite"] docs = [] # Expose the underlying database drivers when a connector is enabled. This is a # way to access database-specific methods when you need extra control. expose-drivers = [] -all = ["mssql", "mysql", "pooled", "postgresql", "sqlite"] +native = [ + "postgresql-native", + "mysql-native", + "mssql-native", + "sqlite-native", +] + +all = ["native", "pooled"] vendored-openssl = [ "postgres-native-tls/vendored-openssl", "mysql_async/vendored-openssl", ] -postgresql = [ +postgresql-native = [ + "postgresql", "native-tls", "tokio-postgres", "postgres-types", @@ -47,11 +55,24 @@ postgresql = [ "lru-cache", "byteorder", ] +postgresql = [] + +mssql-native = [ + "mssql", + "tiberius", + "tokio-util", + "tokio/time", + "tokio/net", +] +mssql = [] + +mysql-native = ["mysql", "mysql_async", "tokio/time", "lru-cache"] +mysql = ["chrono/std"] -mssql = ["tiberius", "tokio-util", "tokio/time", "tokio/net", "either"] -mysql = ["mysql_async", "tokio/time", "lru-cache"] pooled = ["mobc"] -sqlite = ["rusqlite", "tokio/sync"] +sqlite-native = ["sqlite", "rusqlite/bundled", "tokio/sync"] +sqlite = [] + fmt-sql = ["sqlformat"] [dependencies] @@ -67,7 +88,7 @@ futures = "0.3" url = "2.1" hex = "0.4" -either = { version = "1.6", optional = true } +either = { version = "1.6" } base64 = { version = "0.12.3" } chrono = { version = "0.4", default-features = false, features = ["serde"] } lru-cache = { version = "0.1", optional = true } @@ -88,7 +109,11 @@ paste = "1.0" serde = { version = "1.0", features = ["derive"] } quaint-test-macros = { path = "quaint-test-macros" } quaint-test-setup = { path = "quaint-test-setup" } -tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "time"] } +tokio = { version = "1.0", features = ["macros", "time"] } + +[target.'cfg(target_arch = "wasm32")'.dependencies.getrandom] +version = "0.2" +features = ["js"] [dependencies.byteorder] default-features = false @@ -102,7 +127,7 @@ branch = "vendored-openssl" [dependencies.rusqlite] version = "0.29" -features = ["chrono", "bundled", "column_decltype"] +features = ["chrono", "column_decltype"] optional = true [target.'cfg(not(any(target_os = "macos", target_os = "ios")))'.dependencies.tiberius] diff --git a/quaint/README.md b/quaint/README.md index 92033db269b1..03108d9090d3 100644 --- a/quaint/README.md +++ b/quaint/README.md @@ -16,9 +16,13 @@ Quaint is an abstraction over certain SQL databases. It provides: ### Feature flags - `mysql`: Support for MySQL databases. + - On non-WebAssembly targets, choose `mysql-native` instead. - `postgresql`: Support for PostgreSQL databases. + - On non-WebAssembly targets, choose `postgresql-native` instead. - `sqlite`: Support for SQLite databases. + - On non-WebAssembly targets, choose `sqlite-native` instead. - `mssql`: Support for Microsoft SQL Server databases. + - On non-WebAssembly targets, choose `mssql-native` instead. - `pooled`: A connection pool in `pooled::Quaint`. - `vendored-openssl`: Statically links against a vendored OpenSSL library on non-Windows or non-Apple platforms. diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index de8bc64d22bb..dddb3c953ad7 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -10,37 +10,49 @@ //! querying interface. mod connection_info; + pub mod metrics; mod queryable; mod result_set; -#[cfg(any(feature = "mssql", feature = "postgresql", feature = "mysql"))] +#[cfg(any(feature = "mssql-native", feature = "postgresql-native", feature = "mysql-native"))] mod timeout; mod transaction; mod type_identifier; -#[cfg(feature = "mssql")] -pub(crate) mod mssql; -#[cfg(feature = "mysql")] -pub(crate) mod mysql; -#[cfg(feature = "postgresql")] -pub(crate) mod postgres; -#[cfg(feature = "sqlite")] -pub(crate) mod sqlite; - -#[cfg(feature = "mysql")] -pub use self::mysql::*; -#[cfg(feature = "postgresql")] -pub use self::postgres::*; pub use self::result_set::*; pub use connection_info::*; -#[cfg(feature = "mssql")] -pub use mssql::*; pub use queryable::*; -#[cfg(feature = "sqlite")] -pub use sqlite::*; pub use transaction::*; -#[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgresql"))] +#[cfg(any(feature = "mssql-native", feature = "postgresql-native", feature = "mysql-native"))] #[allow(unused_imports)] pub(crate) use type_identifier::*; pub use self::metrics::query; + +#[cfg(feature = "postgresql")] +pub(crate) mod postgres; +#[cfg(feature = "postgresql-native")] +pub use postgres::native::*; +#[cfg(feature = "postgresql")] +pub use postgres::*; + +#[cfg(feature = "mysql")] +pub(crate) mod mysql; +#[cfg(feature = "mysql-native")] +pub use mysql::native::*; +#[cfg(feature = "mysql")] +pub use mysql::*; + +#[cfg(feature = "sqlite")] +pub(crate) mod sqlite; +#[cfg(feature = "sqlite-native")] +pub use sqlite::native::*; +#[cfg(feature = "sqlite")] +pub use sqlite::*; + +#[cfg(feature = "mssql")] +pub(crate) mod mssql; +#[cfg(feature = "mssql-native")] +pub use mssql::native::*; +#[cfg(feature = "mssql")] +pub use mssql::*; diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index cef092edb9d7..e18b68fb2ce1 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -1,614 +1,8 @@ -mod conversion; -mod error; +//! Wasm-compatible definitions for the MSSQL connector. +//! This module is only available with the `mssql` feature. +pub(crate) mod url; -use super::{IsolationLevel, Transaction, TransactionOptions}; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use connection_string::JdbcString; -use futures::lock::Mutex; -use std::{ - convert::TryFrom, - fmt, - future::Future, - str::FromStr, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tiberius::*; -use tokio::net::TcpStream; -use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; +pub use url::*; -/// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use tiberius; - -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug, Clone)] -pub struct MssqlUrl { - connection_string: String, - query_params: MssqlQueryParams, -} - -/// TLS mode when connecting to SQL Server. -#[derive(Debug, Clone, Copy)] -pub enum EncryptMode { - /// All traffic is encrypted. - On, - /// Only the login credentials are encrypted. - Off, - /// Nothing is encrypted. - DangerPlainText, -} - -impl fmt::Display for EncryptMode { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::On => write!(f, "true"), - Self::Off => write!(f, "false"), - Self::DangerPlainText => write!(f, "DANGER_PLAINTEXT"), - } - } -} - -impl FromStr for EncryptMode { - type Err = Error; - - fn from_str(s: &str) -> crate::Result { - let mode = match s.parse::() { - Ok(true) => Self::On, - _ if s == "DANGER_PLAINTEXT" => Self::DangerPlainText, - _ => Self::Off, - }; - - Ok(mode) - } -} - -#[derive(Debug, Clone)] -pub(crate) struct MssqlQueryParams { - encrypt: EncryptMode, - port: Option, - host: Option, - user: Option, - password: Option, - database: String, - schema: String, - trust_server_certificate: bool, - trust_server_certificate_ca: Option, - connection_limit: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - transaction_isolation_level: Option, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, -} - -static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; - -#[async_trait] -impl TransactionCapable for Mssql { - async fn start_transaction<'a>( - &'a self, - isolation: Option, - ) -> crate::Result> { - // Isolation levels in SQL Server are set on the connection and live until they're changed. - // Always explicitly setting the isolation level each time a tx is started (either to the given value - // or by using the default/connection string value) prevents transactions started on connections from - // the pool to have unexpected isolation levels set. - let isolation = isolation - .or(self.url.query_params.transaction_isolation_level) - .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); - - let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); - - Ok(Box::new( - DefaultTransaction::new(self, self.begin_statement(), opts).await?, - )) - } -} - -impl MssqlUrl { - /// Maximum number of connections the pool can have (if used together with - /// pooled Quaint). - pub fn connection_limit(&self) -> Option { - self.query_params.connection_limit() - } - - /// A duration how long one query can take. - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout() - } - - /// A duration how long we can try to connect to the database. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout() - } - - /// A pool check_out timeout. - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout() - } - - /// The isolation level of a transaction. - fn transaction_isolation_level(&self) -> Option { - self.query_params.transaction_isolation_level - } - - /// Name of the database. - pub fn dbname(&self) -> &str { - self.query_params.database() - } - - /// The prefix which to use when querying database. - pub fn schema(&self) -> &str { - self.query_params.schema() - } - - /// Database hostname. - pub fn host(&self) -> &str { - self.query_params.host() - } - - /// The username to use when connecting to the database. - pub fn username(&self) -> Option<&str> { - self.query_params.user() - } - - /// The password to use when connecting to the database. - pub fn password(&self) -> Option<&str> { - self.query_params.password() - } - - /// The TLS mode to use when connecting to the database. - pub fn encrypt(&self) -> EncryptMode { - self.query_params.encrypt() - } - - /// If true, we allow invalid certificates (self-signed, or otherwise - /// dangerous) when connecting. Should be true only for development and - /// testing. - pub fn trust_server_certificate(&self) -> bool { - self.query_params.trust_server_certificate() - } - - /// Path to a custom server certificate file. - pub fn trust_server_certificate_ca(&self) -> Option<&str> { - self.query_params.trust_server_certificate_ca() - } - - /// Database port. - pub fn port(&self) -> u16 { - self.query_params.port() - } - - /// The JDBC connection string - pub fn connection_string(&self) -> &str { - &self.connection_string - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime() - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime() - } -} - -impl MssqlQueryParams { - fn port(&self) -> u16 { - self.port.unwrap_or(1433) - } - - fn host(&self) -> &str { - self.host.as_deref().unwrap_or("localhost") - } - - fn user(&self) -> Option<&str> { - self.user.as_deref() - } - - fn password(&self) -> Option<&str> { - self.password.as_deref() - } - - fn encrypt(&self) -> EncryptMode { - self.encrypt - } - - fn trust_server_certificate(&self) -> bool { - self.trust_server_certificate - } - - fn trust_server_certificate_ca(&self) -> Option<&str> { - self.trust_server_certificate_ca.as_deref() - } - - fn database(&self) -> &str { - &self.database - } - - fn schema(&self) -> &str { - &self.schema - } - - fn socket_timeout(&self) -> Option { - self.socket_timeout - } - - fn connect_timeout(&self) -> Option { - self.connect_timeout - } - - fn connection_limit(&self) -> Option { - self.connection_limit - } - - fn pool_timeout(&self) -> Option { - self.pool_timeout - } - - fn max_connection_lifetime(&self) -> Option { - self.max_connection_lifetime - } - - fn max_idle_connection_lifetime(&self) -> Option { - self.max_idle_connection_lifetime - } -} - -/// A connector interface for the SQL Server database. -#[derive(Debug)] -pub struct Mssql { - client: Mutex>>, - url: MssqlUrl, - socket_timeout: Option, - is_healthy: AtomicBool, -} - -impl Mssql { - /// Creates a new connection to SQL Server. - pub async fn new(url: MssqlUrl) -> crate::Result { - let config = Config::from_jdbc_string(&url.connection_string)?; - let tcp = TcpStream::connect_named(&config).await?; - let socket_timeout = url.socket_timeout(); - - let connecting = async { - match Client::connect(config, tcp.compat_write()).await { - Ok(client) => Ok(client), - Err(tiberius::error::Error::Routing { host, port }) => { - let mut config = Config::from_jdbc_string(&url.connection_string)?; - config.host(host); - config.port(port); - - let tcp = TcpStream::connect_named(&config).await?; - Client::connect(config, tcp.compat_write()).await - } - Err(e) => Err(e), - } - }; - - let client = super::timeout::connect(url.connect_timeout(), connecting).await?; - - let this = Self { - client: Mutex::new(client), - url, - socket_timeout, - is_healthy: AtomicBool::new(true), - }; - - if let Some(isolation) = this.url.transaction_isolation_level() { - this.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation}")) - .await?; - }; - - Ok(this) - } - - /// The underlying Tiberius client. Only available with the `expose-drivers` Cargo feature. - /// This is a lower level API when you need to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn client(&self) -> &Mutex>> { - &self.client - } - - async fn perform_io(&self, fut: F) -> crate::Result - where - F: Future>, - { - match super::timeout::socket(self.socket_timeout, fut).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => res, - } - } -} - -#[async_trait] -impl Queryable for Mssql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mssql::build(q)?; - self.query_raw(&sql, ¶ms[..]).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.query_raw", sql, params, move || async move { - let mut client = self.client.lock().await; - - let mut query = tiberius::Query::new(sql); - - for param in params { - query.bind(param); - } - - let mut results = self.perform_io(query.query(&mut client)).await?.into_results().await?; - - match results.pop() { - Some(rows) => { - let mut columns_set = false; - let mut columns = Vec::new(); - let mut result_rows = Vec::with_capacity(rows.len()); - - for row in rows.into_iter() { - if !columns_set { - columns = row.columns().iter().map(|c| c.name().to_string()).collect(); - columns_set = true; - } - - let mut values: Vec> = Vec::with_capacity(row.len()); - - for val in row.into_iter() { - values.push(Value::try_from(val)?); - } - - result_rows.push(values); - } - - Ok(ResultSet::new(columns, result_rows)) - } - None => Ok(ResultSet::new(Vec::new(), Vec::new())), - } - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mssql::build(q)?; - self.execute_raw(&sql, ¶ms[..]).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.execute_raw", sql, params, move || async move { - let mut query = tiberius::Query::new(sql); - - for param in params { - query.bind(param); - } - - let mut client = self.client.lock().await; - let changes = self.perform_io(query.execute(&mut client)).await?.total(); - - Ok(changes) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mssql.raw_cmd", cmd, &[], move || async move { - let mut client = self.client.lock().await; - self.perform_io(client.simple_query(cmd)).await?.into_results().await?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT @@VERSION AS version"#; - let rows = self.query_raw(query, &[]).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn begin_statement(&self) -> &'static str { - "BEGIN TRAN" - } - - fn requires_isolation_first(&self) -> bool { - true - } -} - -impl MssqlUrl { - pub fn new(jdbc_connection_string: &str) -> crate::Result { - let query_params = Self::parse_query_params(jdbc_connection_string)?; - let connection_string = Self::with_jdbc_prefix(jdbc_connection_string); - - Ok(Self { - connection_string, - query_params, - }) - } - - fn with_jdbc_prefix(input: &str) -> String { - if input.starts_with("jdbc:sqlserver") { - input.into() - } else { - format!("jdbc:{input}") - } - } - - fn parse_query_params(input: &str) -> crate::Result { - let mut conn = JdbcString::from_str(&Self::with_jdbc_prefix(input))?; - - let host = conn.server_name().map(|server_name| match conn.instance_name() { - Some(instance_name) => format!(r#"{server_name}\{instance_name}"#), - None => server_name.to_string(), - }); - - let port = conn.port(); - let props = conn.properties_mut(); - let user = props.remove("user"); - let password = props.remove("password"); - let database = props.remove("database").unwrap_or_else(|| String::from("master")); - let schema = props.remove("schema").unwrap_or_else(|| String::from("dbo")); - - let connection_limit = props - .remove("connectionlimit") - .or_else(|| props.remove("connection_limit")) - .map(|param| param.parse()) - .transpose()?; - - let transaction_isolation_level = props - .remove("isolationlevel") - .or_else(|| props.remove("isolation_level")) - .map(|level| { - IsolationLevel::from_str(&level).map_err(|_| { - let kind = ErrorKind::database_url_is_invalid(format!("Invalid isolation level `{level}`")); - Error::builder(kind).build() - }) - }) - .transpose()?; - - let mut connect_timeout = props - .remove("logintimeout") - .or_else(|| props.remove("login_timeout")) - .or_else(|| props.remove("connecttimeout")) - .or_else(|| props.remove("connect_timeout")) - .or_else(|| props.remove("connectiontimeout")) - .or_else(|| props.remove("connection_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match connect_timeout { - None => connect_timeout = Some(Duration::from_secs(5)), - Some(dur) if dur.as_secs() == 0 => connect_timeout = None, - _ => (), - } - - let mut pool_timeout = props - .remove("pooltimeout") - .or_else(|| props.remove("pool_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match pool_timeout { - None => pool_timeout = Some(Duration::from_secs(10)), - Some(dur) if dur.as_secs() == 0 => pool_timeout = None, - _ => (), - } - - let socket_timeout = props - .remove("sockettimeout") - .or_else(|| props.remove("socket_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - let encrypt = props - .remove("encrypt") - .map(|param| EncryptMode::from_str(¶m)) - .transpose()? - .unwrap_or(EncryptMode::On); - - let trust_server_certificate = props - .remove("trustservercertificate") - .or_else(|| props.remove("trust_server_certificate")) - .map(|param| param.parse()) - .transpose()? - .unwrap_or(false); - - let trust_server_certificate_ca: Option = props - .remove("trustservercertificateca") - .or_else(|| props.remove("trust_server_certificate_ca")); - - let mut max_connection_lifetime = props - .remove("max_connection_lifetime") - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match max_connection_lifetime { - Some(dur) if dur.as_secs() == 0 => max_connection_lifetime = None, - _ => (), - } - - let mut max_idle_connection_lifetime = props - .remove("max_idle_connection_lifetime") - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match max_idle_connection_lifetime { - None => max_idle_connection_lifetime = Some(Duration::from_secs(300)), - Some(dur) if dur.as_secs() == 0 => max_idle_connection_lifetime = None, - _ => (), - } - - Ok(MssqlQueryParams { - encrypt, - port, - host, - user, - password, - database, - schema, - trust_server_certificate, - trust_server_certificate_ca, - connection_limit, - socket_timeout, - connect_timeout, - pool_timeout, - transaction_isolation_level, - max_connection_lifetime, - max_idle_connection_lifetime, - }) - } -} - -#[cfg(test)] -mod tests { - use crate::tests::test_api::mssql::CONN_STR; - use crate::{error::*, single::Quaint}; - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let url = CONN_STR.replace("user=SA", "user=WRONG"); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} +#[cfg(feature = "mssql-native")] +pub(crate) mod native; diff --git a/quaint/src/connector/mssql/conversion.rs b/quaint/src/connector/mssql/native/conversion.rs similarity index 100% rename from quaint/src/connector/mssql/conversion.rs rename to quaint/src/connector/mssql/native/conversion.rs diff --git a/quaint/src/connector/mssql/error.rs b/quaint/src/connector/mssql/native/error.rs similarity index 100% rename from quaint/src/connector/mssql/error.rs rename to quaint/src/connector/mssql/native/error.rs diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs new file mode 100644 index 000000000000..d22aa7a15dd6 --- /dev/null +++ b/quaint/src/connector/mssql/native/mod.rs @@ -0,0 +1,239 @@ +//! Definitions for the MSSQL connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `mssql-native` feature. +mod conversion; +mod error; + +pub(crate) use crate::connector::mssql::MssqlUrl; +use crate::connector::{timeout, IsolationLevel, Transaction, TransactionOptions}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use futures::lock::Mutex; +use std::{ + convert::TryFrom, + future::Future, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tiberius::*; +use tokio::net::TcpStream; +use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; + +/// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use tiberius; + +static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; + +#[async_trait] +impl TransactionCapable for Mssql { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> crate::Result> { + // Isolation levels in SQL Server are set on the connection and live until they're changed. + // Always explicitly setting the isolation level each time a tx is started (either to the given value + // or by using the default/connection string value) prevents transactions started on connections from + // the pool to have unexpected isolation levels set. + let isolation = isolation + .or(self.url.query_params.transaction_isolation_level) + .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); + + let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); + + Ok(Box::new( + DefaultTransaction::new(self, self.begin_statement(), opts).await?, + )) + } +} + +/// A connector interface for the SQL Server database. +#[derive(Debug)] +pub struct Mssql { + client: Mutex>>, + url: MssqlUrl, + socket_timeout: Option, + is_healthy: AtomicBool, +} + +impl Mssql { + /// Creates a new connection to SQL Server. + pub async fn new(url: MssqlUrl) -> crate::Result { + let config = Config::from_jdbc_string(&url.connection_string)?; + let tcp = TcpStream::connect_named(&config).await?; + let socket_timeout = url.socket_timeout(); + + let connecting = async { + match Client::connect(config, tcp.compat_write()).await { + Ok(client) => Ok(client), + Err(tiberius::error::Error::Routing { host, port }) => { + let mut config = Config::from_jdbc_string(&url.connection_string)?; + config.host(host); + config.port(port); + + let tcp = TcpStream::connect_named(&config).await?; + Client::connect(config, tcp.compat_write()).await + } + Err(e) => Err(e), + } + }; + + let client = timeout::connect(url.connect_timeout(), connecting).await?; + + let this = Self { + client: Mutex::new(client), + url, + socket_timeout, + is_healthy: AtomicBool::new(true), + }; + + if let Some(isolation) = this.url.transaction_isolation_level() { + this.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation}")) + .await?; + }; + + Ok(this) + } + + /// The underlying Tiberius client. Only available with the `expose-drivers` Cargo feature. + /// This is a lower level API when you need to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn client(&self) -> &Mutex>> { + &self.client + } + + async fn perform_io(&self, fut: F) -> crate::Result + where + F: Future>, + { + match timeout::socket(self.socket_timeout, fut).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => res, + } + } +} + +#[async_trait] +impl Queryable for Mssql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mssql::build(q)?; + self.query_raw(&sql, ¶ms[..]).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mssql.query_raw", sql, params, move || async move { + let mut client = self.client.lock().await; + + let mut query = tiberius::Query::new(sql); + + for param in params { + query.bind(param); + } + + let mut results = self.perform_io(query.query(&mut client)).await?.into_results().await?; + + match results.pop() { + Some(rows) => { + let mut columns_set = false; + let mut columns = Vec::new(); + let mut result_rows = Vec::with_capacity(rows.len()); + + for row in rows.into_iter() { + if !columns_set { + columns = row.columns().iter().map(|c| c.name().to_string()).collect(); + columns_set = true; + } + + let mut values: Vec> = Vec::with_capacity(row.len()); + + for val in row.into_iter() { + values.push(Value::try_from(val)?); + } + + result_rows.push(values); + } + + Ok(ResultSet::new(columns, result_rows)) + } + None => Ok(ResultSet::new(Vec::new(), Vec::new())), + } + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mssql::build(q)?; + self.execute_raw(&sql, ¶ms[..]).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mssql.execute_raw", sql, params, move || async move { + let mut query = tiberius::Query::new(sql); + + for param in params { + query.bind(param); + } + + let mut client = self.client.lock().await; + let changes = self.perform_io(query.execute(&mut client)).await?.total(); + + Ok(changes) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("mssql.raw_cmd", cmd, &[], move || async move { + let mut client = self.client.lock().await; + self.perform_io(client.simple_query(cmd)).await?.into_results().await?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT @@VERSION AS version"#; + let rows = self.query_raw(query, &[]).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn begin_statement(&self) -> &'static str { + "BEGIN TRAN" + } + + fn requires_isolation_first(&self) -> bool { + true + } +} diff --git a/quaint/src/connector/mssql/url.rs b/quaint/src/connector/mssql/url.rs new file mode 100644 index 000000000000..42cc0868f9bf --- /dev/null +++ b/quaint/src/connector/mssql/url.rs @@ -0,0 +1,384 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use crate::{ + connector::IsolationLevel, + error::{Error, ErrorKind}, +}; +use connection_string::JdbcString; +use std::{fmt, str::FromStr, time::Duration}; + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug, Clone)] +pub struct MssqlUrl { + pub(crate) connection_string: String, + pub(crate) query_params: MssqlQueryParams, +} + +/// TLS mode when connecting to SQL Server. +#[derive(Debug, Clone, Copy)] +pub enum EncryptMode { + /// All traffic is encrypted. + On, + /// Only the login credentials are encrypted. + Off, + /// Nothing is encrypted. + DangerPlainText, +} + +impl fmt::Display for EncryptMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::On => write!(f, "true"), + Self::Off => write!(f, "false"), + Self::DangerPlainText => write!(f, "DANGER_PLAINTEXT"), + } + } +} + +impl FromStr for EncryptMode { + type Err = Error; + + fn from_str(s: &str) -> crate::Result { + let mode = match s.parse::() { + Ok(true) => Self::On, + _ if s == "DANGER_PLAINTEXT" => Self::DangerPlainText, + _ => Self::Off, + }; + + Ok(mode) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MssqlQueryParams { + pub(crate) encrypt: EncryptMode, + pub(crate) port: Option, + pub(crate) host: Option, + pub(crate) user: Option, + pub(crate) password: Option, + pub(crate) database: String, + pub(crate) schema: String, + pub(crate) trust_server_certificate: bool, + pub(crate) trust_server_certificate_ca: Option, + pub(crate) connection_limit: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) transaction_isolation_level: Option, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, +} + +impl MssqlUrl { + /// Maximum number of connections the pool can have (if used together with + /// pooled Quaint). + pub fn connection_limit(&self) -> Option { + self.query_params.connection_limit() + } + + /// A duration how long one query can take. + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout() + } + + /// A duration how long we can try to connect to the database. + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout() + } + + /// A pool check_out timeout. + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout() + } + + /// The isolation level of a transaction. + pub(crate) fn transaction_isolation_level(&self) -> Option { + self.query_params.transaction_isolation_level + } + + /// Name of the database. + pub fn dbname(&self) -> &str { + self.query_params.database() + } + + /// The prefix which to use when querying database. + pub fn schema(&self) -> &str { + self.query_params.schema() + } + + /// Database hostname. + pub fn host(&self) -> &str { + self.query_params.host() + } + + /// The username to use when connecting to the database. + pub fn username(&self) -> Option<&str> { + self.query_params.user() + } + + /// The password to use when connecting to the database. + pub fn password(&self) -> Option<&str> { + self.query_params.password() + } + + /// The TLS mode to use when connecting to the database. + pub fn encrypt(&self) -> EncryptMode { + self.query_params.encrypt() + } + + /// If true, we allow invalid certificates (self-signed, or otherwise + /// dangerous) when connecting. Should be true only for development and + /// testing. + pub fn trust_server_certificate(&self) -> bool { + self.query_params.trust_server_certificate() + } + + /// Path to a custom server certificate file. + pub fn trust_server_certificate_ca(&self) -> Option<&str> { + self.query_params.trust_server_certificate_ca() + } + + /// Database port. + pub fn port(&self) -> u16 { + self.query_params.port() + } + + /// The JDBC connection string + pub fn connection_string(&self) -> &str { + &self.connection_string + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime() + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime() + } +} + +impl MssqlQueryParams { + fn port(&self) -> u16 { + self.port.unwrap_or(1433) + } + + fn host(&self) -> &str { + self.host.as_deref().unwrap_or("localhost") + } + + fn user(&self) -> Option<&str> { + self.user.as_deref() + } + + fn password(&self) -> Option<&str> { + self.password.as_deref() + } + + fn encrypt(&self) -> EncryptMode { + self.encrypt + } + + fn trust_server_certificate(&self) -> bool { + self.trust_server_certificate + } + + fn trust_server_certificate_ca(&self) -> Option<&str> { + self.trust_server_certificate_ca.as_deref() + } + + fn database(&self) -> &str { + &self.database + } + + fn schema(&self) -> &str { + &self.schema + } + + fn socket_timeout(&self) -> Option { + self.socket_timeout + } + + fn connect_timeout(&self) -> Option { + self.connect_timeout + } + + fn connection_limit(&self) -> Option { + self.connection_limit + } + + fn pool_timeout(&self) -> Option { + self.pool_timeout + } + + fn max_connection_lifetime(&self) -> Option { + self.max_connection_lifetime + } + + fn max_idle_connection_lifetime(&self) -> Option { + self.max_idle_connection_lifetime + } +} + +impl MssqlUrl { + pub fn new(jdbc_connection_string: &str) -> crate::Result { + let query_params = Self::parse_query_params(jdbc_connection_string)?; + let connection_string = Self::with_jdbc_prefix(jdbc_connection_string); + + Ok(Self { + connection_string, + query_params, + }) + } + + fn with_jdbc_prefix(input: &str) -> String { + if input.starts_with("jdbc:sqlserver") { + input.into() + } else { + format!("jdbc:{input}") + } + } + + fn parse_query_params(input: &str) -> crate::Result { + let mut conn = JdbcString::from_str(&Self::with_jdbc_prefix(input))?; + + let host = conn.server_name().map(|server_name| match conn.instance_name() { + Some(instance_name) => format!(r#"{server_name}\{instance_name}"#), + None => server_name.to_string(), + }); + + let port = conn.port(); + let props = conn.properties_mut(); + let user = props.remove("user"); + let password = props.remove("password"); + let database = props.remove("database").unwrap_or_else(|| String::from("master")); + let schema = props.remove("schema").unwrap_or_else(|| String::from("dbo")); + + let connection_limit = props + .remove("connectionlimit") + .or_else(|| props.remove("connection_limit")) + .map(|param| param.parse()) + .transpose()?; + + let transaction_isolation_level = props + .remove("isolationlevel") + .or_else(|| props.remove("isolation_level")) + .map(|level| { + IsolationLevel::from_str(&level).map_err(|_| { + let kind = ErrorKind::database_url_is_invalid(format!("Invalid isolation level `{level}`")); + Error::builder(kind).build() + }) + }) + .transpose()?; + + let mut connect_timeout = props + .remove("logintimeout") + .or_else(|| props.remove("login_timeout")) + .or_else(|| props.remove("connecttimeout")) + .or_else(|| props.remove("connect_timeout")) + .or_else(|| props.remove("connectiontimeout")) + .or_else(|| props.remove("connection_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match connect_timeout { + None => connect_timeout = Some(Duration::from_secs(5)), + Some(dur) if dur.as_secs() == 0 => connect_timeout = None, + _ => (), + } + + let mut pool_timeout = props + .remove("pooltimeout") + .or_else(|| props.remove("pool_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match pool_timeout { + None => pool_timeout = Some(Duration::from_secs(10)), + Some(dur) if dur.as_secs() == 0 => pool_timeout = None, + _ => (), + } + + let socket_timeout = props + .remove("sockettimeout") + .or_else(|| props.remove("socket_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + let encrypt = props + .remove("encrypt") + .map(|param| EncryptMode::from_str(¶m)) + .transpose()? + .unwrap_or(EncryptMode::On); + + let trust_server_certificate = props + .remove("trustservercertificate") + .or_else(|| props.remove("trust_server_certificate")) + .map(|param| param.parse()) + .transpose()? + .unwrap_or(false); + + let trust_server_certificate_ca: Option = props + .remove("trustservercertificateca") + .or_else(|| props.remove("trust_server_certificate_ca")); + + let mut max_connection_lifetime = props + .remove("max_connection_lifetime") + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match max_connection_lifetime { + Some(dur) if dur.as_secs() == 0 => max_connection_lifetime = None, + _ => (), + } + + let mut max_idle_connection_lifetime = props + .remove("max_idle_connection_lifetime") + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match max_idle_connection_lifetime { + None => max_idle_connection_lifetime = Some(Duration::from_secs(300)), + Some(dur) if dur.as_secs() == 0 => max_idle_connection_lifetime = None, + _ => (), + } + + Ok(MssqlQueryParams { + encrypt, + port, + host, + user, + password, + database, + schema, + trust_server_certificate, + trust_server_certificate_ca, + connection_limit, + socket_timeout, + connect_timeout, + pool_timeout, + transaction_isolation_level, + max_connection_lifetime, + max_idle_connection_lifetime, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::tests::test_api::mssql::CONN_STR; + use crate::{error::*, single::Quaint}; + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let url = CONN_STR.replace("user=SA", "user=WRONG"); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 4b6f27a583da..0dc504dd2d11 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -1,669 +1,10 @@ -mod conversion; -mod error; - -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use lru_cache::LruCache; -use mysql_async::{ - self as my, - prelude::{Query as _, Queryable as _}, -}; -use percent_encoding::percent_decode; -use std::{ - borrow::Cow, - future::Future, - path::{Path, PathBuf}, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tokio::sync::Mutex; -use url::{Host, Url}; +//! Wasm-compatible definitions for the MySQL connector. +//! This module is only available with the `mysql` feature. +pub(crate) mod error; +pub(crate) mod url; pub use error::MysqlError; +pub use url::*; -/// The underlying MySQL driver. Only available with the `expose-drivers` -/// Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use mysql_async; - -use super::IsolationLevel; - -/// A connector interface for the MySQL database. -#[derive(Debug)] -pub struct Mysql { - pub(crate) conn: Mutex, - pub(crate) url: MysqlUrl, - socket_timeout: Option, - is_healthy: AtomicBool, - statement_cache: Mutex>, -} - -/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. -#[derive(Debug, Clone)] -pub struct MysqlUrl { - url: Url, - query_params: MysqlUrlQueryParams, -} - -impl MysqlUrl { - /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { url, query_params }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. - pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Option> { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => Some(password), - None => self.url.password().map(|s| s.into()), - } - } - - /// Name of the database connected. Defaults to `mysql`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("mysql"), - None => "mysql", - } - } - - /// The database host. If `socket` and `host` are not set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.url.host(), self.url.host_str()) { - (Some(Host::Ipv6(_)), Some(host)) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (_, Some(host)) => host, - _ => "localhost", - } - } - - /// If set, connected to the database through a Unix socket. - pub fn socket(&self) -> &Option { - &self.query_params.socket - } - - /// The database port, defaults to `3306`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(3306) - } - - /// The connection timeout. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// The pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// Prefer socket connection - pub fn prefer_socket(&self) -> Option { - self.query_params.prefer_socket - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - fn statement_cache_size(&self) -> usize { - self.query_params.statement_cache_size - } - - pub(crate) fn cache(&self) -> LruCache { - LruCache::new(self.query_params.statement_cache_size) - } - - fn parse_query_params(url: &Url) -> Result { - let mut ssl_opts = my::SslOpts::default(); - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); - - let mut connection_limit = None; - let mut use_ssl = false; - let mut socket = None; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut prefer_socket = None; - let mut statement_cache_size = 100; - let mut identity: Option<(Option, Option)> = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslcert" => { - use_ssl = true; - ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); - } - "sslidentity" => { - use_ssl = true; - - identity = match identity { - Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)), - None => Some((Some(Path::new(&*v).to_path_buf()), None)), - }; - } - "sslpassword" => { - use_ssl = true; - - identity = match identity { - Some((path, _)) => Some((path, Some(v.to_string()))), - None => Some((None, Some(v.to_string()))), - }; - } - "socket" => { - socket = Some(v.replace(['(', ')'], "")); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "prefer_socket" => { - let as_bool = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - prefer_socket = Some(as_bool) - } - "connect_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connect_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "pool_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - pool_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "sslaccept" => { - use_ssl = true; - match v.as_ref() { - "strict" => { - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); - } - "accept_invalid_certs" => {} - _ => { - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`", - mode = &*v - ); - } - }; - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - ssl_opts = match identity { - Some((Some(path), Some(pw))) => { - let identity = mysql_async::ClientIdentity::new(path).with_password(pw); - ssl_opts.with_client_identity(Some(identity)) - } - Some((Some(path), None)) => { - let identity = mysql_async::ClientIdentity::new(path); - ssl_opts.with_client_identity(Some(identity)) - } - _ => ssl_opts, - }; - - Ok(MysqlUrlQueryParams { - ssl_opts, - connection_limit, - use_ssl, - socket, - socket_timeout, - connect_timeout, - pool_timeout, - max_connection_lifetime, - max_idle_connection_lifetime, - prefer_socket, - statement_cache_size, - }) - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit - } - - pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder { - let mut config = my::OptsBuilder::default() - .stmt_cache_size(Some(0)) - .user(Some(self.username())) - .pass(self.password()) - .db_name(Some(self.dbname())); - - match self.socket() { - Some(ref socket) => { - config = config.socket(Some(socket)); - } - None => { - config = config.ip_or_hostname(self.host()).tcp_port(self.port()); - } - } - - config = config.conn_ttl(Some(Duration::from_secs(5))); - - if self.query_params.use_ssl { - config = config.ssl_opts(Some(self.query_params.ssl_opts.clone())); - } - - if self.query_params.prefer_socket.is_some() { - config = config.prefer_socket(self.query_params.prefer_socket); - } - - config - } -} - -#[derive(Debug, Clone)] -pub(crate) struct MysqlUrlQueryParams { - ssl_opts: my::SslOpts, - connection_limit: Option, - use_ssl: bool, - socket: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, - prefer_socket: Option, - statement_cache_size: usize, -} - -impl Mysql { - /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate. - pub async fn new(url: MysqlUrl) -> crate::Result { - let conn = super::timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?; - - Ok(Self { - socket_timeout: url.query_params.socket_timeout, - conn: Mutex::new(conn), - statement_cache: Mutex::new(url.cache()), - url, - is_healthy: AtomicBool::new(true), - }) - } - - /// The underlying mysql_async::Conn. Only available with the - /// `expose-drivers` Cargo feature. This is a lower level API when you need - /// to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn conn(&self) -> &Mutex { - &self.conn - } - - async fn perform_io(&self, op: U) -> crate::Result - where - F: Future>, - U: FnOnce() -> F, - { - match super::timeout::socket(self.socket_timeout, op()).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => Ok(res?), - } - } - - async fn prepared(&self, sql: &str, op: U) -> crate::Result - where - F: Future>, - U: Fn(my::Statement) -> F, - { - if self.url.statement_cache_size() == 0 { - self.perform_io(|| async move { - let stmt = { - let mut conn = self.conn.lock().await; - conn.prep(sql).await? - }; - - let res = op(stmt.clone()).await; - - { - let mut conn = self.conn.lock().await; - conn.close(stmt).await?; - } - - res - }) - .await - } else { - self.perform_io(|| async move { - let stmt = self.fetch_cached(sql).await?; - op(stmt).await - }) - .await - } - } - - async fn fetch_cached(&self, sql: &str) -> crate::Result { - let mut cache = self.statement_cache.lock().await; - let capacity = cache.capacity(); - let stored = cache.len(); - - match cache.get_mut(sql) { - Some(stmt) => { - tracing::trace!( - message = "CACHE HIT!", - query = sql, - capacity = capacity, - stored = stored, - ); - - Ok(stmt.clone()) // arc'd - } - None => { - tracing::trace!( - message = "CACHE MISS!", - query = sql, - capacity = capacity, - stored = stored, - ); - - let mut conn = self.conn.lock().await; - if cache.capacity() == cache.len() { - if let Some((_, stmt)) = cache.remove_lru() { - conn.close(stmt).await?; - } - } - - let stmt = conn.prep(sql).await?; - cache.insert(sql.to_string(), stmt.clone()); - - Ok(stmt) - } - } - } -} - -impl_default_TransactionCapable!(Mysql); - -#[async_trait] -impl Queryable for Mysql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mysql::build(q)?; - self.query_raw(&sql, ¶ms).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.query_raw", sql, params, move || async move { - self.prepared(sql, |stmt| async move { - let mut conn = self.conn.lock().await; - let rows: Vec = conn.exec(&stmt, conversion::conv_params(params)?).await?; - let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect(); - - let last_id = conn.last_insert_id(); - let mut result_set = ResultSet::new(columns, Vec::new()); - - for mut row in rows { - result_set.rows.push(row.take_result_row()?); - } - - if let Some(id) = last_id { - result_set.set_last_insert_id(id); - }; - - Ok(result_set) - }) - .await - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mysql::build(q)?; - self.execute_raw(&sql, ¶ms).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.execute_raw", sql, params, move || async move { - self.prepared(sql, |stmt| async move { - let mut conn = self.conn.lock().await; - conn.exec_drop(stmt, conversion::conv_params(params)?).await?; - - Ok(conn.affected_rows()) - }) - .await - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mysql.raw_cmd", cmd, &[], move || async move { - self.perform_io(|| async move { - let mut conn = self.conn.lock().await; - let mut result = cmd.run(&mut *conn).await?; - - loop { - result.map(drop).await?; - - if result.is_empty() { - result.map(drop).await?; - break; - } - } - - Ok(()) - }) - .await - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT @@GLOBAL.version version"#; - let rows = super::timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.typed.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - if matches!(isolation_level, IsolationLevel::Snapshot) { - return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); - } - - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::MysqlUrl; - use crate::tests::test_api::mysql::CONN_STR; - use crate::{error::*, single::Quaint}; - use url::Url; - - #[test] - fn should_parse_socket_url() { - let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); - } - - #[test] - fn should_parse_prefer_socket() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); - assert!(!url.prefer_socket().unwrap()); - } - - #[test] - fn should_parse_sslaccept() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); - assert!(url.query_params.use_ssl); - assert!(!url.query_params.ssl_opts.skip_domain_validation()); - assert!(!url.query_params.ssl_opts.accept_invalid_certs()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) - .unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("root").unwrap(); - url.set_path("/this_does_not_exist"); - - let url = url.as_str().to_string(); - let res = Quaint::new(&url).await; - - let err = res.unwrap_err(); - - match err.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("1049"), err.original_code()); - assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} +#[cfg(feature = "mysql-native")] +pub(crate) mod native; diff --git a/quaint/src/connector/mysql/error.rs b/quaint/src/connector/mysql/error.rs index dd7c3d3bfa66..7b4813bf0223 100644 --- a/quaint/src/connector/mysql/error.rs +++ b/quaint/src/connector/mysql/error.rs @@ -1,22 +1,23 @@ use crate::error::{DatabaseConstraint, Error, ErrorKind}; -use mysql_async as my; +use thiserror::Error; + +// This is a partial copy of the `mysql_async::Error` using only the enum variant used by Prisma. +// This avoids pulling in `mysql_async`, which would break Wasm compilation. +#[derive(Debug, Error)] +enum MysqlAsyncError { + #[error("Server error: `{}'", _0)] + Server(#[source] MysqlError), +} +/// This type represents MySql server error. +#[derive(Debug, Error, Clone, Eq, PartialEq)] +#[error("ERROR {} ({}): {}", state, code, message)] pub struct MysqlError { pub code: u16, pub message: String, pub state: String, } -impl From<&my::ServerError> for MysqlError { - fn from(value: &my::ServerError) -> Self { - MysqlError { - code: value.code, - message: value.message.to_owned(), - state: value.state.to_owned(), - } - } -} - impl From for Error { fn from(error: MysqlError) -> Self { let code = error.code; @@ -232,7 +233,7 @@ impl From for Error { } _ => { let kind = ErrorKind::QueryError( - my::Error::Server(my::ServerError { + MysqlAsyncError::Server(MysqlError { message: error.message.clone(), code, state: error.state.clone(), @@ -249,24 +250,3 @@ impl From for Error { } } } - -impl From for Error { - fn from(e: my::Error) -> Error { - match e { - my::Error::Io(my::IoError::Tls(err)) => Error::builder(ErrorKind::TlsError { - message: err.to_string(), - }) - .build(), - my::Error::Io(my::IoError::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => { - Error::builder(ErrorKind::ConnectionClosed).build() - } - my::Error::Io(io_error) => Error::builder(ErrorKind::ConnectionError(io_error.into())).build(), - my::Error::Driver(e) => Error::builder(ErrorKind::QueryError(e.into())).build(), - my::Error::Server(ref server_error) => { - let mysql_error: MysqlError = server_error.into(); - mysql_error.into() - } - e => Error::builder(ErrorKind::QueryError(e.into())).build(), - } - } -} diff --git a/quaint/src/connector/mysql/conversion.rs b/quaint/src/connector/mysql/native/conversion.rs similarity index 100% rename from quaint/src/connector/mysql/conversion.rs rename to quaint/src/connector/mysql/native/conversion.rs diff --git a/quaint/src/connector/mysql/native/error.rs b/quaint/src/connector/mysql/native/error.rs new file mode 100644 index 000000000000..89c21fb706f6 --- /dev/null +++ b/quaint/src/connector/mysql/native/error.rs @@ -0,0 +1,36 @@ +use crate::{ + connector::mysql::error::MysqlError, + error::{Error, ErrorKind}, +}; +use mysql_async as my; + +impl From<&my::ServerError> for MysqlError { + fn from(value: &my::ServerError) -> Self { + MysqlError { + code: value.code, + message: value.message.to_owned(), + state: value.state.to_owned(), + } + } +} + +impl From for Error { + fn from(e: my::Error) -> Error { + match e { + my::Error::Io(my::IoError::Tls(err)) => Error::builder(ErrorKind::TlsError { + message: err.to_string(), + }) + .build(), + my::Error::Io(my::IoError::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => { + Error::builder(ErrorKind::ConnectionClosed).build() + } + my::Error::Io(io_error) => Error::builder(ErrorKind::ConnectionError(io_error.into())).build(), + my::Error::Driver(e) => Error::builder(ErrorKind::QueryError(e.into())).build(), + my::Error::Server(ref server_error) => { + let mysql_error: MysqlError = server_error.into(); + mysql_error.into() + } + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs new file mode 100644 index 000000000000..fdcc3a6276d1 --- /dev/null +++ b/quaint/src/connector/mysql/native/mod.rs @@ -0,0 +1,297 @@ +//! Definitions for the MySQL connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `mysql-native` feature. +mod conversion; +mod error; + +pub(crate) use crate::connector::mysql::MysqlUrl; +use crate::connector::{timeout, IsolationLevel}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use lru_cache::LruCache; +use mysql_async::{ + self as my, + prelude::{Query as _, Queryable as _}, +}; +use std::{ + future::Future, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tokio::sync::Mutex; + +/// The underlying MySQL driver. Only available with the `expose-drivers` +/// Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use mysql_async; + +impl MysqlUrl { + pub(crate) fn cache(&self) -> LruCache { + LruCache::new(self.query_params.statement_cache_size) + } + + pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder { + let mut config = my::OptsBuilder::default() + .stmt_cache_size(Some(0)) + .user(Some(self.username())) + .pass(self.password()) + .db_name(Some(self.dbname())); + + match self.socket() { + Some(ref socket) => { + config = config.socket(Some(socket)); + } + None => { + config = config.ip_or_hostname(self.host()).tcp_port(self.port()); + } + } + + config = config.conn_ttl(Some(Duration::from_secs(5))); + + if self.query_params.use_ssl { + config = config.ssl_opts(Some(self.query_params.ssl_opts.clone())); + } + + if self.query_params.prefer_socket.is_some() { + config = config.prefer_socket(self.query_params.prefer_socket); + } + + config + } +} + +/// A connector interface for the MySQL database. +#[derive(Debug)] +pub struct Mysql { + pub(crate) conn: Mutex, + pub(crate) url: MysqlUrl, + socket_timeout: Option, + is_healthy: AtomicBool, + statement_cache: Mutex>, +} + +impl Mysql { + /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate. + pub async fn new(url: MysqlUrl) -> crate::Result { + let conn = timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?; + + Ok(Self { + socket_timeout: url.query_params.socket_timeout, + conn: Mutex::new(conn), + statement_cache: Mutex::new(url.cache()), + url, + is_healthy: AtomicBool::new(true), + }) + } + + /// The underlying mysql_async::Conn. Only available with the + /// `expose-drivers` Cargo feature. This is a lower level API when you need + /// to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn conn(&self) -> &Mutex { + &self.conn + } + + async fn perform_io(&self, op: U) -> crate::Result + where + F: Future>, + U: FnOnce() -> F, + { + match timeout::socket(self.socket_timeout, op()).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => Ok(res?), + } + } + + async fn prepared(&self, sql: &str, op: U) -> crate::Result + where + F: Future>, + U: Fn(my::Statement) -> F, + { + if self.url.statement_cache_size() == 0 { + self.perform_io(|| async move { + let stmt = { + let mut conn = self.conn.lock().await; + conn.prep(sql).await? + }; + + let res = op(stmt.clone()).await; + + { + let mut conn = self.conn.lock().await; + conn.close(stmt).await?; + } + + res + }) + .await + } else { + self.perform_io(|| async move { + let stmt = self.fetch_cached(sql).await?; + op(stmt).await + }) + .await + } + } + + async fn fetch_cached(&self, sql: &str) -> crate::Result { + let mut cache = self.statement_cache.lock().await; + let capacity = cache.capacity(); + let stored = cache.len(); + + match cache.get_mut(sql) { + Some(stmt) => { + tracing::trace!( + message = "CACHE HIT!", + query = sql, + capacity = capacity, + stored = stored, + ); + + Ok(stmt.clone()) // arc'd + } + None => { + tracing::trace!( + message = "CACHE MISS!", + query = sql, + capacity = capacity, + stored = stored, + ); + + let mut conn = self.conn.lock().await; + if cache.capacity() == cache.len() { + if let Some((_, stmt)) = cache.remove_lru() { + conn.close(stmt).await?; + } + } + + let stmt = conn.prep(sql).await?; + cache.insert(sql.to_string(), stmt.clone()); + + Ok(stmt) + } + } + } +} + +impl_default_TransactionCapable!(Mysql); + +#[async_trait] +impl Queryable for Mysql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mysql::build(q)?; + self.query_raw(&sql, ¶ms).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mysql.query_raw", sql, params, move || async move { + self.prepared(sql, |stmt| async move { + let mut conn = self.conn.lock().await; + let rows: Vec = conn.exec(&stmt, conversion::conv_params(params)?).await?; + let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect(); + + let last_id = conn.last_insert_id(); + let mut result_set = ResultSet::new(columns, Vec::new()); + + for mut row in rows { + result_set.rows.push(row.take_result_row()?); + } + + if let Some(id) = last_id { + result_set.set_last_insert_id(id); + }; + + Ok(result_set) + }) + .await + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mysql::build(q)?; + self.execute_raw(&sql, ¶ms).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mysql.execute_raw", sql, params, move || async move { + self.prepared(sql, |stmt| async move { + let mut conn = self.conn.lock().await; + conn.exec_drop(stmt, conversion::conv_params(params)?).await?; + + Ok(conn.affected_rows()) + }) + .await + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("mysql.raw_cmd", cmd, &[], move || async move { + self.perform_io(|| async move { + let mut conn = self.conn.lock().await; + let mut result = cmd.run(&mut *conn).await?; + + loop { + result.map(drop).await?; + + if result.is_empty() { + result.map(drop).await?; + break; + } + } + + Ok(()) + }) + .await + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT @@GLOBAL.version version"#; + let rows = timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.typed.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + if matches!(isolation_level, IsolationLevel::Snapshot) { + return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); + } + + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + true + } +} diff --git a/quaint/src/connector/mysql/url.rs b/quaint/src/connector/mysql/url.rs new file mode 100644 index 000000000000..f0756fa95833 --- /dev/null +++ b/quaint/src/connector/mysql/url.rs @@ -0,0 +1,401 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use crate::error::{Error, ErrorKind}; +use percent_encoding::percent_decode; +use std::{ + borrow::Cow, + path::{Path, PathBuf}, + time::Duration, +}; +use url::{Host, Url}; + +/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. +#[derive(Debug, Clone)] +pub struct MysqlUrl { + url: Url, + pub(crate) query_params: MysqlUrlQueryParams, +} + +impl MysqlUrl { + /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { url, query_params }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Option> { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => Some(password), + None => self.url.password().map(|s| s.into()), + } + } + + /// Name of the database connected. Defaults to `mysql`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("mysql"), + None => "mysql", + } + } + + /// The database host. If `socket` and `host` are not set, defaults to `localhost`. + pub fn host(&self) -> &str { + match (self.url.host(), self.url.host_str()) { + (Some(Host::Ipv6(_)), Some(host)) => { + // The `url` crate may return an IPv6 address in brackets, which must be stripped. + if host.starts_with('[') && host.ends_with(']') { + &host[1..host.len() - 1] + } else { + host + } + } + (_, Some(host)) => host, + _ => "localhost", + } + } + + /// If set, connected to the database through a Unix socket. + pub fn socket(&self) -> &Option { + &self.query_params.socket + } + + /// The database port, defaults to `3306`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(3306) + } + + /// The connection timeout. + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + /// The pool check_out timeout + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout + } + + /// The socket timeout + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + /// Prefer socket connection + pub fn prefer_socket(&self) -> Option { + self.query_params.prefer_socket + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime + } + + pub(crate) fn statement_cache_size(&self) -> usize { + self.query_params.statement_cache_size + } + + fn parse_query_params(url: &Url) -> Result { + #[cfg(feature = "mysql-native")] + let mut ssl_opts = { + let mut ssl_opts = mysql_async::SslOpts::default(); + ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); + ssl_opts + }; + + let mut connection_limit = None; + let mut use_ssl = false; + let mut socket = None; + let mut socket_timeout = None; + let mut connect_timeout = Some(Duration::from_secs(5)); + let mut pool_timeout = Some(Duration::from_secs(10)); + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); + let mut prefer_socket = None; + let mut statement_cache_size = 100; + let mut identity: Option<(Option, Option)> = None; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connection_limit = Some(as_int); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslcert" => { + use_ssl = true; + + #[cfg(feature = "mysql-native")] + { + ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); + } + } + "sslidentity" => { + use_ssl = true; + + identity = match identity { + Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)), + None => Some((Some(Path::new(&*v).to_path_buf()), None)), + }; + } + "sslpassword" => { + use_ssl = true; + + identity = match identity { + Some((path, _)) => Some((path, Some(v.to_string()))), + None => Some((None, Some(v.to_string()))), + }; + } + "socket" => { + socket = Some(v.replace(['(', ')'], "")); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "prefer_socket" => { + let as_bool = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + prefer_socket = Some(as_bool) + } + "connect_timeout" => { + let as_int = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connect_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "pool_timeout" => { + let as_int = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + pool_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "sslaccept" => { + use_ssl = true; + match v.as_ref() { + "strict" => { + #[cfg(feature = "mysql-native")] + { + ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); + } + } + "accept_invalid_certs" => {} + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`", + mode = &*v + ); + } + }; + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + // Wrapping this in a block, as attributes on expressions are still experimental + // See: https://github.com/rust-lang/rust/issues/15701 + #[cfg(feature = "mysql-native")] + { + ssl_opts = match identity { + Some((Some(path), Some(pw))) => { + let identity = mysql_async::ClientIdentity::new(path).with_password(pw); + ssl_opts.with_client_identity(Some(identity)) + } + Some((Some(path), None)) => { + let identity = mysql_async::ClientIdentity::new(path); + ssl_opts.with_client_identity(Some(identity)) + } + _ => ssl_opts, + }; + } + + Ok(MysqlUrlQueryParams { + #[cfg(feature = "mysql-native")] + ssl_opts, + connection_limit, + use_ssl, + socket, + socket_timeout, + connect_timeout, + pool_timeout, + max_connection_lifetime, + max_idle_connection_lifetime, + prefer_socket, + statement_cache_size, + }) + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MysqlUrlQueryParams { + pub(crate) connection_limit: Option, + pub(crate) use_ssl: bool, + pub(crate) socket: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) prefer_socket: Option, + pub(crate) statement_cache_size: usize, + + #[cfg(feature = "mysql-native")] + pub(crate) ssl_opts: mysql_async::SslOpts, +} + +#[cfg(test)] +mod tests { + use super::MysqlUrl; + use crate::tests::test_api::mysql::CONN_STR; + use crate::{error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); + } + + #[test] + fn should_parse_prefer_socket() { + let url = + MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); + assert!(!url.prefer_socket().unwrap()); + } + + #[test] + fn should_parse_sslaccept() { + let url = + MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); + assert!(url.query_params.use_ssl); + assert!(!url.query_params.ssl_opts.skip_domain_validation()); + assert!(!url.query_params.ssl_opts.accept_invalid_certs()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); + assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) + .unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("root").unwrap(); + url.set_path("/this_does_not_exist"); + + let url = url.as_str().to_string(); + let res = Quaint::new(&url).await; + + let err = res.unwrap_err(); + + match err.kind() { + ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("1049"), err.original_code()); + assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 766be38b27e4..d1694108a1b7 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -1,1593 +1,10 @@ -mod conversion; -mod error; - -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use futures::{future::FutureExt, lock::Mutex}; -use lru_cache::LruCache; -use native_tls::{Certificate, Identity, TlsConnector}; -use percent_encoding::percent_decode; -use postgres_native_tls::MakeTlsConnector; -use std::{ - borrow::{Borrow, Cow}, - fmt::{Debug, Display}, - fs, - future::Future, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tokio_postgres::{ - config::{ChannelBinding, SslMode}, - Client, Config, Statement, -}; -use url::{Host, Url}; +//! Wasm-compatible definitions for the PostgreSQL connector. +//! This module is only available with the `postgresql` feature. +pub(crate) mod error; +pub(crate) mod url; pub use error::PostgresError; +pub use url::*; -pub(crate) const DEFAULT_SCHEMA: &str = "public"; - -/// The underlying postgres driver. Only available with the `expose-drivers` -/// Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use tokio_postgres; - -use super::{IsolationLevel, Transaction}; - -#[derive(Clone)] -struct Hidden(T); - -impl Debug for Hidden { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("") - } -} - -struct PostgresClient(Client); - -impl Debug for PostgresClient { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("PostgresClient") - } -} - -/// A connector interface for the PostgreSQL database. -#[derive(Debug)] -pub struct PostgreSql { - client: PostgresClient, - pg_bouncer: bool, - socket_timeout: Option, - statement_cache: Mutex>, - is_healthy: AtomicBool, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum SslAcceptMode { - Strict, - AcceptInvalidCerts, -} - -#[derive(Debug, Clone)] -pub struct SslParams { - certificate_file: Option, - identity_file: Option, - identity_password: Hidden>, - ssl_accept_mode: SslAcceptMode, -} - -#[derive(Debug)] -struct SslAuth { - certificate: Hidden>, - identity: Hidden>, - ssl_accept_mode: SslAcceptMode, -} - -impl Default for SslAuth { - fn default() -> Self { - Self { - certificate: Hidden(None), - identity: Hidden(None), - ssl_accept_mode: SslAcceptMode::AcceptInvalidCerts, - } - } -} - -impl SslAuth { - fn certificate(&mut self, certificate: Certificate) -> &mut Self { - self.certificate = Hidden(Some(certificate)); - self - } - - fn identity(&mut self, identity: Identity) -> &mut Self { - self.identity = Hidden(Some(identity)); - self - } - - fn accept_mode(&mut self, mode: SslAcceptMode) -> &mut Self { - self.ssl_accept_mode = mode; - self - } -} - -impl SslParams { - async fn into_auth(self) -> crate::Result { - let mut auth = SslAuth::default(); - auth.accept_mode(self.ssl_accept_mode); - - if let Some(ref cert_file) = self.certificate_file { - let cert = fs::read(cert_file).map_err(|err| { - Error::builder(ErrorKind::TlsError { - message: format!("cert file not found ({err})"), - }) - .build() - })?; - - auth.certificate(Certificate::from_pem(&cert)?); - } - - if let Some(ref identity_file) = self.identity_file { - let db = fs::read(identity_file).map_err(|err| { - Error::builder(ErrorKind::TlsError { - message: format!("identity file not found ({err})"), - }) - .build() - })?; - let password = self.identity_password.0.as_deref().unwrap_or(""); - let identity = Identity::from_pkcs12(&db, password)?; - - auth.identity(identity); - } - - Ok(auth) - } -} - -#[derive(Debug, Clone, Copy)] -pub enum PostgresFlavour { - Postgres, - Cockroach, - Unknown, -} - -impl PostgresFlavour { - /// Returns `true` if the postgres flavour is [`Postgres`]. - /// - /// [`Postgres`]: PostgresFlavour::Postgres - fn is_postgres(&self) -> bool { - matches!(self, Self::Postgres) - } - - /// Returns `true` if the postgres flavour is [`Cockroach`]. - /// - /// [`Cockroach`]: PostgresFlavour::Cockroach - fn is_cockroach(&self) -> bool { - matches!(self, Self::Cockroach) - } - - /// Returns `true` if the postgres flavour is [`Unknown`]. - /// - /// [`Unknown`]: PostgresFlavour::Unknown - fn is_unknown(&self) -> bool { - matches!(self, Self::Unknown) - } -} - -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug, Clone)] -pub struct PostgresUrl { - url: Url, - query_params: PostgresUrlQueryParams, - flavour: PostgresFlavour, -} - -impl PostgresUrl { - /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { - url, - query_params, - flavour: PostgresFlavour::Unknown, - }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. - pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The database host. Taken first from the `host` query parameter, then - /// from the `host` part of the URL. For socket connections, the query - /// parameter must be used. - /// - /// If none of them are set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) { - (Some(host), _, _) => host.as_str(), - (None, Some(""), _) => "localhost", - (None, None, _) => "localhost", - (None, Some(host), Some(Host::Ipv6(_))) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (None, Some(host), _) => host, - } - } - - /// Name of the database connected. Defaults to `postgres`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("postgres"), - None => "postgres", - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Cow { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => password, - None => self.url.password().unwrap_or("").into(), - } - } - - /// The database port, defaults to `5432`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(5432) - } - - /// The database schema, defaults to `public`. - pub fn schema(&self) -> &str { - self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) - } - - /// Whether the pgbouncer mode is enabled. - pub fn pg_bouncer(&self) -> bool { - self.query_params.pg_bouncer - } - - /// The connection timeout. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// Pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - /// The custom application name - pub fn application_name(&self) -> Option<&str> { - self.query_params.application_name.as_deref() - } - - pub fn channel_binding(&self) -> ChannelBinding { - self.query_params.channel_binding - } - - pub(crate) fn cache(&self) -> LruCache { - if self.query_params.pg_bouncer { - LruCache::new(0) - } else { - LruCache::new(self.query_params.statement_cache_size) - } - } - - pub(crate) fn options(&self) -> Option<&str> { - self.query_params.options.as_deref() - } - - /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. - /// This is used to avoid a network roundtrip at connection to set the search path. - /// - /// The different behaviours are: - /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. - /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. - /// - Unknown: Always add a network roundtrip by setting the search path through a database query. - pub fn set_flavour(&mut self, flavour: PostgresFlavour) { - self.flavour = flavour; - } - - fn parse_query_params(url: &Url) -> Result { - let mut connection_limit = None; - let mut schema = None; - let mut certificate_file = None; - let mut identity_file = None; - let mut identity_password = None; - let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - let mut ssl_mode = SslMode::Prefer; - let mut host = None; - let mut application_name = None; - let mut channel_binding = ChannelBinding::Prefer; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut pg_bouncer = false; - let mut statement_cache_size = 100; - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut options = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "pgbouncer" => { - pg_bouncer = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslmode" => { - match v.as_ref() { - "disable" => ssl_mode = SslMode::Disable, - "prefer" => ssl_mode = SslMode::Prefer, - "require" => ssl_mode = SslMode::Require, - _ => { - tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); - } - }; - } - "sslcert" => { - certificate_file = Some(v.to_string()); - } - "sslidentity" => { - identity_file = Some(v.to_string()); - } - "sslpassword" => { - identity_password = Some(v.to_string()); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslaccept" => { - match v.as_ref() { - "strict" => { - ssl_accept_mode = SslAcceptMode::Strict; - } - "accept_invalid_certs" => { - ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - } - _ => { - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `strict`", - mode = &*v - ); - - ssl_accept_mode = SslAcceptMode::Strict; - } - }; - } - "schema" => { - schema = Some(v.to_string()); - } - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - connection_limit = Some(as_int); - } - "host" => { - host = Some(v.to_string()); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "connect_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - connect_timeout = None; - } else { - connect_timeout = Some(Duration::from_secs(as_int)); - } - } - "pool_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - pool_timeout = None; - } else { - pool_timeout = Some(Duration::from_secs(as_int)); - } - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "application_name" => { - application_name = Some(v.to_string()); - } - "channel_binding" => { - match v.as_ref() { - "disable" => channel_binding = ChannelBinding::Disable, - "prefer" => channel_binding = ChannelBinding::Prefer, - "require" => channel_binding = ChannelBinding::Require, - _ => { - tracing::debug!( - message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", - channel_binding = &*v - ); - } - }; - } - "options" => { - options = Some(v.to_string()); - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - Ok(PostgresUrlQueryParams { - ssl_params: SslParams { - certificate_file, - identity_file, - ssl_accept_mode, - identity_password: Hidden(identity_password), - }, - connection_limit, - schema, - ssl_mode, - host, - connect_timeout, - pool_timeout, - socket_timeout, - pg_bouncer, - statement_cache_size, - max_connection_lifetime, - max_idle_connection_lifetime, - application_name, - channel_binding, - options, - }) - } - - pub(crate) fn ssl_params(&self) -> &SslParams { - &self.query_params.ssl_params - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit - } - - /// On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. - /// We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. - /// To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. - fn set_search_path(&self, config: &mut Config) { - // PGBouncer does not support the search_path connection parameter. - // https://www.pgbouncer.org/config.html#ignore_startup_parameters - if self.query_params.pg_bouncer { - return; - } - - if let Some(schema) = &self.query_params.schema { - if self.flavour().is_cockroach() && is_safe_identifier(schema) { - config.search_path(CockroachSearchPath(schema).to_string()); - } - - if self.flavour().is_postgres() { - config.search_path(PostgresSearchPath(schema).to_string()); - } - } - } - - pub(crate) fn to_config(&self) -> Config { - let mut config = Config::new(); - - config.user(self.username().borrow()); - config.password(self.password().borrow() as &str); - config.host(self.host()); - config.port(self.port()); - config.dbname(self.dbname()); - config.pgbouncer_mode(self.query_params.pg_bouncer); - - if let Some(options) = self.options() { - config.options(options); - } - - if let Some(application_name) = self.application_name() { - config.application_name(application_name); - } - - if let Some(connect_timeout) = self.query_params.connect_timeout { - config.connect_timeout(connect_timeout); - } - - self.set_search_path(&mut config); - - config.ssl_mode(self.query_params.ssl_mode); - - config.channel_binding(self.query_params.channel_binding); - - config - } - - pub fn flavour(&self) -> PostgresFlavour { - self.flavour - } -} - -#[derive(Debug, Clone)] -pub(crate) struct PostgresUrlQueryParams { - ssl_params: SslParams, - connection_limit: Option, - schema: Option, - ssl_mode: SslMode, - pg_bouncer: bool, - host: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - statement_cache_size: usize, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, - application_name: Option, - channel_binding: ChannelBinding, - options: Option, -} - -impl PostgreSql { - /// Create a new connection to the database. - pub async fn new(url: PostgresUrl) -> crate::Result { - let config = url.to_config(); - - let mut tls_builder = TlsConnector::builder(); - - { - let ssl_params = url.ssl_params(); - let auth = ssl_params.to_owned().into_auth().await?; - - if let Some(certificate) = auth.certificate.0 { - tls_builder.add_root_certificate(certificate); - } - - tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); - - if let Some(identity) = auth.identity.0 { - tls_builder.identity(identity); - } - } - - let tls = MakeTlsConnector::new(tls_builder.build()?); - let (client, conn) = super::timeout::connect(url.connect_timeout(), config.connect(tls)).await?; - - tokio::spawn(conn.map(|r| match r { - Ok(_) => (), - Err(e) => { - tracing::error!("Error in PostgreSQL connection: {:?}", e); - } - })); - - // On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. - // We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. - // To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. - // Finally, to ensure backward compatibility, we keep sending a database query in case the flavour is set to Unknown. - if let Some(schema) = &url.query_params.schema { - // PGBouncer does not support the search_path connection parameter. - // https://www.pgbouncer.org/config.html#ignore_startup_parameters - if url.query_params.pg_bouncer - || url.flavour().is_unknown() - || (url.flavour().is_cockroach() && !is_safe_identifier(schema)) - { - let session_variables = format!( - r##"{set_search_path}"##, - set_search_path = SetSearchPath(url.query_params.schema.as_deref()) - ); - - client.simple_query(session_variables.as_str()).await?; - } - } - - Ok(Self { - client: PostgresClient(client), - socket_timeout: url.query_params.socket_timeout, - pg_bouncer: url.query_params.pg_bouncer, - statement_cache: Mutex::new(url.cache()), - is_healthy: AtomicBool::new(true), - }) - } - - /// The underlying tokio_postgres::Client. Only available with the - /// `expose-drivers` Cargo feature. This is a lower level API when you need - /// to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn client(&self) -> &tokio_postgres::Client { - &self.client.0 - } - - async fn fetch_cached(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - let mut cache = self.statement_cache.lock().await; - let capacity = cache.capacity(); - let stored = cache.len(); - - match cache.get_mut(sql) { - Some(stmt) => { - tracing::trace!( - message = "CACHE HIT!", - query = sql, - capacity = capacity, - stored = stored, - ); - - Ok(stmt.clone()) // arc'd - } - None => { - tracing::trace!( - message = "CACHE MISS!", - query = sql, - capacity = capacity, - stored = stored, - ); - - let param_types = conversion::params_to_types(params); - let stmt = self.perform_io(self.client.0.prepare_typed(sql, ¶m_types)).await?; - - cache.insert(sql.to_string(), stmt.clone()); - - Ok(stmt) - } - } - } - - async fn perform_io(&self, fut: F) -> crate::Result - where - F: Future>, - { - match super::timeout::socket(self.socket_timeout, fut).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => res, - } - } - - fn check_bind_variables_len(&self, params: &[Value<'_>]) -> crate::Result<()> { - if params.len() > i16::MAX as usize { - // tokio_postgres would return an error here. Let's avoid calling the driver - // and return an error early. - let kind = ErrorKind::QueryInvalidInput(format!( - "too many bind variables in prepared statement, expected maximum of {}, received {}", - i16::MAX, - params.len() - )); - Err(Error::builder(kind).build()) - } else { - Ok(()) - } - } -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct CockroachSearchPath<'a>(&'a str); - -impl Display for CockroachSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.0) - } -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct PostgresSearchPath<'a>(&'a str); - -impl Display for PostgresSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("\"")?; - f.write_str(self.0)?; - f.write_str("\"")?; - - Ok(()) - } -} - -// A SetSearchPath statement (Display-impl) for connection initialization. -struct SetSearchPath<'a>(Option<&'a str>); - -impl Display for SetSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(schema) = self.0 { - f.write_str("SET search_path = \"")?; - f.write_str(schema)?; - f.write_str("\";\n")?; - } - - Ok(()) - } -} - -impl_default_TransactionCapable!(PostgreSql); - -#[async_trait] -impl Queryable for PostgreSql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Postgres::build(q)?; - - self.query_raw(sql.as_str(), ¶ms[..]).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); - - for row in rows { - result.rows.push(row.get_result_row()?); - } - - Ok(result) - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); - - for row in rows { - result.rows.push(row.get_result_row()?); - } - - Ok(result) - }) - .await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Postgres::build(q)?; - - self.execute_raw(sql.as_str(), ¶ms[..]).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - Ok(changes) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - Ok(changes) - }) - .await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("postgres.raw_cmd", cmd, &[], move || async move { - self.perform_io(self.client.0.simple_query(cmd)).await?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT version()"#; - let rows = self.query_raw(query, &[]).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn server_reset_query(&self, tx: &dyn Transaction) -> crate::Result<()> { - if self.pg_bouncer { - tx.raw_cmd("DEALLOCATE ALL").await - } else { - Ok(()) - } - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - if matches!(isolation_level, IsolationLevel::Snapshot) { - return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); - } - - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - false - } -} - -/// Sorted list of CockroachDB's reserved keywords. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_KEYWORDS: [&str; 79] = [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "both", - "case", - "cast", - "check", - "collate", - "column", - "concurrently", - "constraint", - "create", - "current_catalog", - "current_date", - "current_role", - "current_schema", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "fetch", - "for", - "foreign", - "from", - "grant", - "group", - "having", - "in", - "initially", - "intersect", - "into", - "lateral", - "leading", - "limit", - "localtime", - "localtimestamp", - "not", - "null", - "offset", - "on", - "only", - "or", - "order", - "placing", - "primary", - "references", - "returning", - "select", - "session_user", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "variadic", - "when", - "where", - "window", - "with", -]; - -/// Sorted list of CockroachDB's reserved type function names. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ - "authorization", - "collation", - "cross", - "full", - "ilike", - "inner", - "is", - "isnull", - "join", - "left", - "like", - "natural", - "none", - "notnull", - "outer", - "overlaps", - "right", - "similar", -]; - -/// Returns true if a Postgres identifier is considered "safe". -/// -/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. -/// -/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers -fn is_safe_identifier(ident: &str) -> bool { - if ident.is_empty() { - return false; - } - - // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. - if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { - return false; - } - - let mut chars = ident.chars(); - - let first = chars.next().unwrap(); - - // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). - if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { - return false; - } - - for c in chars { - // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). - if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { - return false; - } - } - - true -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::tests::test_api::postgres::CONN_STR; - use crate::tests::test_api::CRDB_CONN_STR; - use crate::{connector::Queryable, error::*, single::Quaint}; - use url::Url; - - #[test] - fn should_parse_socket_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/psql.sock", url.host()); - } - - #[test] - fn should_parse_escaped_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/postgresql", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[test] - fn should_have_application_name() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); - assert_eq!(Some("test"), url.application_name()); - } - - #[test] - fn should_have_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Require, url.channel_binding()); - } - - #[test] - fn should_have_default_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - } - - #[test] - fn should_not_enable_caching_with_pgbouncer() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); - assert_eq!(0, url.cache().capacity()); - } - - #[test] - fn should_parse_default_host() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("localhost", url.host()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_handle_options_field() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) - .unwrap(); - - assert_eq!("--cluster=my_cluster", url.options().unwrap()); - } - - #[tokio::test] - async fn test_custom_search_path_pg() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); - assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_pg_pgbouncer() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - url.query_pairs_mut().append_pair("pbbouncer", "true"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); - assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_crdb() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("àbracadabra")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("héllo")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("héll0$")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("héll_0$")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_unknown_pg() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Unknown); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_unknown_crdb() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Unknown); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_path("/this_does_not_exist"); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("3D000"), e.original_code()); - assert_eq!( - Some("database \"this_does_not_exist\" does not exist"), - e.original_message() - ); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), - }, - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } - - #[tokio::test] - async fn should_map_tls_errors() { - let mut url = Url::parse(&CONN_STR).expect("parsing url"); - url.set_query(Some("sslmode=require&sslaccept=strict")); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::TlsError { .. } => (), - other => panic!("{:#?}", other), - }, - } - } - - #[tokio::test] - async fn should_map_incorrect_parameters_error() { - let url = Url::parse(&CONN_STR).unwrap(); - let conn = Quaint::new(url.as_str()).await.unwrap(); - - let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::IncorrectNumberOfParameters { expected, actual } => { - assert_eq!(1, *expected); - assert_eq!(2, *actual); - } - other => panic!("{:#?}", other), - }, - } - } - - #[test] - fn test_safe_ident() { - // Safe - assert!(is_safe_identifier("hello")); - assert!(is_safe_identifier("_hello")); - assert!(is_safe_identifier("àbracadabra")); - assert!(is_safe_identifier("h3ll0")); - assert!(is_safe_identifier("héllo")); - assert!(is_safe_identifier("héll0$")); - assert!(is_safe_identifier("héll_0$")); - assert!(is_safe_identifier("disconnect_security_must_honor_connect_scope_one2m")); - - // Not safe - assert!(!is_safe_identifier("")); - assert!(!is_safe_identifier("Hello")); - assert!(!is_safe_identifier("hEllo")); - assert!(!is_safe_identifier("$hello")); - assert!(!is_safe_identifier("hello!")); - assert!(!is_safe_identifier("hello#")); - assert!(!is_safe_identifier("he llo")); - assert!(!is_safe_identifier(" hello")); - assert!(!is_safe_identifier("he-llo")); - assert!(!is_safe_identifier("hÉllo")); - assert!(!is_safe_identifier("1337")); - assert!(!is_safe_identifier("_HELLO")); - assert!(!is_safe_identifier("HELLO")); - assert!(!is_safe_identifier("HELLO$")); - assert!(!is_safe_identifier("ÀBRACADABRA")); - - for ident in RESERVED_KEYWORDS { - assert!(!is_safe_identifier(ident)); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert!(!is_safe_identifier(ident)); - } - } - - #[test] - fn search_path_pgbouncer_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - url.query_pairs_mut().append_pair("pgbouncer", "true"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // PGBouncer does not support the `search_path` connection parameter. - // When `pgbouncer=true`, config.search_path should be None, - // And the `search_path` should be set via a db query after connection. - assert_eq!(config.get_search_path(), None); - } - - #[test] - fn search_path_pg_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // Postgres supports setting the search_path via a connection parameter. - assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); - } - - #[test] - fn search_path_crdb_safe_ident_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB supports setting the search_path via a connection parameter if the identifier is safe. - assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); - } - - #[test] - fn search_path_crdb_unsafe_ident_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "HeLLo"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. - assert_eq!(config.get_search_path(), None); - } -} +#[cfg(feature = "postgresql-native")] +pub(crate) mod native; diff --git a/quaint/src/connector/postgres/error.rs b/quaint/src/connector/postgres/error.rs index d4e5ec7837fe..ab6ec7b07847 100644 --- a/quaint/src/connector/postgres/error.rs +++ b/quaint/src/connector/postgres/error.rs @@ -1,7 +1,5 @@ use std::fmt::{Display, Formatter}; -use tokio_postgres::error::DbError; - use crate::error::{DatabaseConstraint, Error, ErrorKind, Name}; #[derive(Debug)] @@ -17,7 +15,7 @@ pub struct PostgresError { impl std::error::Error for PostgresError {} impl Display for PostgresError { - // copy of DbError::fmt + // copy of tokio_postgres::error::DbError::fmt fn fmt(&self, fmt: &mut Formatter<'_>) -> std::fmt::Result { write!(fmt, "{}: {}", self.severity, self.message)?; if let Some(detail) = &self.detail { @@ -30,19 +28,6 @@ impl Display for PostgresError { } } -impl From<&DbError> for PostgresError { - fn from(value: &DbError) -> Self { - PostgresError { - code: value.code().code().to_string(), - severity: value.severity().to_string(), - message: value.message().to_string(), - detail: value.detail().map(ToString::to_string), - column: value.column().map(ToString::to_string), - hint: value.hint().map(ToString::to_string), - } - } -} - impl From for Error { fn from(value: PostgresError) -> Self { match value.code.as_str() { @@ -245,110 +230,3 @@ impl From for Error { } } } - -impl From for Error { - fn from(e: tokio_postgres::error::Error) -> Error { - if e.is_closed() { - return Error::builder(ErrorKind::ConnectionClosed).build(); - } - - if let Some(db_error) = e.as_db_error() { - return PostgresError::from(db_error).into(); - } - - if let Some(tls_error) = try_extracting_tls_error(&e) { - return tls_error; - } - - // Same for IO errors. - if let Some(io_error) = try_extracting_io_error(&e) { - return io_error; - } - - if let Some(uuid_error) = try_extracting_uuid_error(&e) { - return uuid_error; - } - - let reason = format!("{e}"); - let code = e.code().map(|c| c.code()); - - match reason.as_str() { - "error connecting to server: timed out" => { - let mut builder = Error::builder(ErrorKind::ConnectTimeout); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } // sigh... - // https://github.com/sfackler/rust-postgres/blob/0c84ed9f8201f4e5b4803199a24afa2c9f3723b2/tokio-postgres/src/connect_tls.rs#L37 - "error performing TLS handshake: server does not support TLS" => { - let mut builder = Error::builder(ErrorKind::TlsError { - message: reason.clone(), - }); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } // double sigh - _ => { - let code = code.map(|c| c.to_string()); - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } - } - } -} - -fn try_extracting_uuid_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error as _; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| ErrorKind::UUIDError(format!("{err}"))) - .map(|kind| Error::builder(kind).build()) -} - -fn try_extracting_tls_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| err.into()) -} - -fn try_extracting_io_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error as _; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| ErrorKind::ConnectionError(Box::new(std::io::Error::new(err.kind(), format!("{err}"))))) - .map(|kind| Error::builder(kind).build()) -} - -impl From for Error { - fn from(e: native_tls::Error) -> Error { - Error::from(&e) - } -} - -impl From<&native_tls::Error> for Error { - fn from(e: &native_tls::Error) -> Error { - let kind = ErrorKind::TlsError { - message: format!("{e}"), - }; - - Error::builder(kind).build() - } -} diff --git a/quaint/src/connector/postgres/conversion.rs b/quaint/src/connector/postgres/native/conversion.rs similarity index 100% rename from quaint/src/connector/postgres/conversion.rs rename to quaint/src/connector/postgres/native/conversion.rs diff --git a/quaint/src/connector/postgres/conversion/decimal.rs b/quaint/src/connector/postgres/native/conversion/decimal.rs similarity index 100% rename from quaint/src/connector/postgres/conversion/decimal.rs rename to quaint/src/connector/postgres/native/conversion/decimal.rs diff --git a/quaint/src/connector/postgres/native/error.rs b/quaint/src/connector/postgres/native/error.rs new file mode 100644 index 000000000000..c353e397705c --- /dev/null +++ b/quaint/src/connector/postgres/native/error.rs @@ -0,0 +1,126 @@ +use tokio_postgres::error::DbError; + +use crate::{ + connector::postgres::error::PostgresError, + error::{Error, ErrorKind}, +}; + +impl From<&DbError> for PostgresError { + fn from(value: &DbError) -> Self { + PostgresError { + code: value.code().code().to_string(), + severity: value.severity().to_string(), + message: value.message().to_string(), + detail: value.detail().map(ToString::to_string), + column: value.column().map(ToString::to_string), + hint: value.hint().map(ToString::to_string), + } + } +} + +impl From for Error { + fn from(e: tokio_postgres::error::Error) -> Error { + if e.is_closed() { + return Error::builder(ErrorKind::ConnectionClosed).build(); + } + + if let Some(db_error) = e.as_db_error() { + return PostgresError::from(db_error).into(); + } + + if let Some(tls_error) = try_extracting_tls_error(&e) { + return tls_error; + } + + // Same for IO errors. + if let Some(io_error) = try_extracting_io_error(&e) { + return io_error; + } + + if let Some(uuid_error) = try_extracting_uuid_error(&e) { + return uuid_error; + } + + let reason = format!("{e}"); + let code = e.code().map(|c| c.code()); + + match reason.as_str() { + "error connecting to server: timed out" => { + let mut builder = Error::builder(ErrorKind::ConnectTimeout); + + if let Some(code) = code { + builder.set_original_code(code); + }; + + builder.set_original_message(reason); + builder.build() + } // sigh... + // https://github.com/sfackler/rust-postgres/blob/0c84ed9f8201f4e5b4803199a24afa2c9f3723b2/tokio-postgres/src/connect_tls.rs#L37 + "error performing TLS handshake: server does not support TLS" => { + let mut builder = Error::builder(ErrorKind::TlsError { + message: reason.clone(), + }); + + if let Some(code) = code { + builder.set_original_code(code); + }; + + builder.set_original_message(reason); + builder.build() + } // double sigh + _ => { + let code = code.map(|c| c.to_string()); + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); + + if let Some(code) = code { + builder.set_original_code(code); + }; + + builder.set_original_message(reason); + builder.build() + } + } + } +} + +fn try_extracting_uuid_error(err: &tokio_postgres::error::Error) -> Option { + use std::error::Error as _; + + err.source() + .and_then(|err| err.downcast_ref::()) + .map(|err| ErrorKind::UUIDError(format!("{err}"))) + .map(|kind| Error::builder(kind).build()) +} + +fn try_extracting_tls_error(err: &tokio_postgres::error::Error) -> Option { + use std::error::Error; + + err.source() + .and_then(|err| err.downcast_ref::()) + .map(|err| err.into()) +} + +fn try_extracting_io_error(err: &tokio_postgres::error::Error) -> Option { + use std::error::Error as _; + + err.source() + .and_then(|err| err.downcast_ref::()) + .map(|err| ErrorKind::ConnectionError(Box::new(std::io::Error::new(err.kind(), format!("{err}"))))) + .map(|kind| Error::builder(kind).build()) +} + +impl From for Error { + fn from(e: native_tls::Error) -> Error { + Error::from(&e) + } +} + +impl From<&native_tls::Error> for Error { + fn from(e: &native_tls::Error) -> Error { + let kind = ErrorKind::TlsError { + message: format!("{e}"), + }; + + Error::builder(kind).build() + } +} diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs new file mode 100644 index 000000000000..30f34e7002be --- /dev/null +++ b/quaint/src/connector/postgres/native/mod.rs @@ -0,0 +1,972 @@ +//! Definitions for the Postgres connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `postgresql-native` feature. +mod conversion; +mod error; + +pub(crate) use crate::connector::postgres::url::PostgresUrl; +use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams}; +use crate::connector::{timeout, IsolationLevel, Transaction}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use futures::{future::FutureExt, lock::Mutex}; +use lru_cache::LruCache; +use native_tls::{Certificate, Identity, TlsConnector}; +use postgres_native_tls::MakeTlsConnector; +use std::{ + borrow::Borrow, + fmt::{Debug, Display}, + fs, + future::Future, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; + +/// The underlying postgres driver. Only available with the `expose-drivers` +/// Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use tokio_postgres; + +struct PostgresClient(Client); + +impl Debug for PostgresClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("PostgresClient") + } +} + +/// A connector interface for the PostgreSQL database. +#[derive(Debug)] +pub struct PostgreSql { + client: PostgresClient, + pg_bouncer: bool, + socket_timeout: Option, + statement_cache: Mutex>, + is_healthy: AtomicBool, +} + +#[derive(Debug)] +struct SslAuth { + certificate: Hidden>, + identity: Hidden>, + ssl_accept_mode: SslAcceptMode, +} + +impl Default for SslAuth { + fn default() -> Self { + Self { + certificate: Hidden(None), + identity: Hidden(None), + ssl_accept_mode: SslAcceptMode::AcceptInvalidCerts, + } + } +} + +impl SslAuth { + fn certificate(&mut self, certificate: Certificate) -> &mut Self { + self.certificate = Hidden(Some(certificate)); + self + } + + fn identity(&mut self, identity: Identity) -> &mut Self { + self.identity = Hidden(Some(identity)); + self + } + + fn accept_mode(&mut self, mode: SslAcceptMode) -> &mut Self { + self.ssl_accept_mode = mode; + self + } +} + +impl SslParams { + async fn into_auth(self) -> crate::Result { + let mut auth = SslAuth::default(); + auth.accept_mode(self.ssl_accept_mode); + + if let Some(ref cert_file) = self.certificate_file { + let cert = fs::read(cert_file).map_err(|err| { + Error::builder(ErrorKind::TlsError { + message: format!("cert file not found ({err})"), + }) + .build() + })?; + + auth.certificate(Certificate::from_pem(&cert)?); + } + + if let Some(ref identity_file) = self.identity_file { + let db = fs::read(identity_file).map_err(|err| { + Error::builder(ErrorKind::TlsError { + message: format!("identity file not found ({err})"), + }) + .build() + })?; + let password = self.identity_password.0.as_deref().unwrap_or(""); + let identity = Identity::from_pkcs12(&db, password)?; + + auth.identity(identity); + } + + Ok(auth) + } +} + +impl PostgresUrl { + pub(crate) fn cache(&self) -> LruCache { + if self.query_params.pg_bouncer { + LruCache::new(0) + } else { + LruCache::new(self.query_params.statement_cache_size) + } + } + + pub fn channel_binding(&self) -> ChannelBinding { + self.query_params.channel_binding + } + + /// On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. + /// We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. + /// To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. + fn set_search_path(&self, config: &mut Config) { + // PGBouncer does not support the search_path connection parameter. + // https://www.pgbouncer.org/config.html#ignore_startup_parameters + if self.query_params.pg_bouncer { + return; + } + + if let Some(schema) = &self.query_params.schema { + if self.flavour().is_cockroach() && is_safe_identifier(schema) { + config.search_path(CockroachSearchPath(schema).to_string()); + } + + if self.flavour().is_postgres() { + config.search_path(PostgresSearchPath(schema).to_string()); + } + } + } + + pub(crate) fn to_config(&self) -> Config { + let mut config = Config::new(); + + config.user(self.username().borrow()); + config.password(self.password().borrow() as &str); + config.host(self.host()); + config.port(self.port()); + config.dbname(self.dbname()); + config.pgbouncer_mode(self.query_params.pg_bouncer); + + if let Some(options) = self.options() { + config.options(options); + } + + if let Some(application_name) = self.application_name() { + config.application_name(application_name); + } + + if let Some(connect_timeout) = self.query_params.connect_timeout { + config.connect_timeout(connect_timeout); + } + + self.set_search_path(&mut config); + + config.ssl_mode(self.query_params.ssl_mode); + + config.channel_binding(self.query_params.channel_binding); + + config + } +} + +impl PostgreSql { + /// Create a new connection to the database. + pub async fn new(url: PostgresUrl) -> crate::Result { + let config = url.to_config(); + + let mut tls_builder = TlsConnector::builder(); + + { + let ssl_params = url.ssl_params(); + let auth = ssl_params.to_owned().into_auth().await?; + + if let Some(certificate) = auth.certificate.0 { + tls_builder.add_root_certificate(certificate); + } + + tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); + + if let Some(identity) = auth.identity.0 { + tls_builder.identity(identity); + } + } + + let tls = MakeTlsConnector::new(tls_builder.build()?); + let (client, conn) = timeout::connect(url.connect_timeout(), config.connect(tls)).await?; + + tokio::spawn(conn.map(|r| match r { + Ok(_) => (), + Err(e) => { + tracing::error!("Error in PostgreSQL connection: {:?}", e); + } + })); + + // On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. + // We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. + // To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. + // Finally, to ensure backward compatibility, we keep sending a database query in case the flavour is set to Unknown. + if let Some(schema) = &url.query_params.schema { + // PGBouncer does not support the search_path connection parameter. + // https://www.pgbouncer.org/config.html#ignore_startup_parameters + if url.query_params.pg_bouncer + || url.flavour().is_unknown() + || (url.flavour().is_cockroach() && !is_safe_identifier(schema)) + { + let session_variables = format!( + r##"{set_search_path}"##, + set_search_path = SetSearchPath(url.query_params.schema.as_deref()) + ); + + client.simple_query(session_variables.as_str()).await?; + } + } + + Ok(Self { + client: PostgresClient(client), + socket_timeout: url.query_params.socket_timeout, + pg_bouncer: url.query_params.pg_bouncer, + statement_cache: Mutex::new(url.cache()), + is_healthy: AtomicBool::new(true), + }) + } + + /// The underlying tokio_postgres::Client. Only available with the + /// `expose-drivers` Cargo feature. This is a lower level API when you need + /// to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn client(&self) -> &tokio_postgres::Client { + &self.client.0 + } + + async fn fetch_cached(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + let mut cache = self.statement_cache.lock().await; + let capacity = cache.capacity(); + let stored = cache.len(); + + match cache.get_mut(sql) { + Some(stmt) => { + tracing::trace!( + message = "CACHE HIT!", + query = sql, + capacity = capacity, + stored = stored, + ); + + Ok(stmt.clone()) // arc'd + } + None => { + tracing::trace!( + message = "CACHE MISS!", + query = sql, + capacity = capacity, + stored = stored, + ); + + let param_types = conversion::params_to_types(params); + let stmt = self.perform_io(self.client.0.prepare_typed(sql, ¶m_types)).await?; + + cache.insert(sql.to_string(), stmt.clone()); + + Ok(stmt) + } + } + } + + async fn perform_io(&self, fut: F) -> crate::Result + where + F: Future>, + { + match timeout::socket(self.socket_timeout, fut).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => res, + } + } + + fn check_bind_variables_len(&self, params: &[Value<'_>]) -> crate::Result<()> { + if params.len() > i16::MAX as usize { + // tokio_postgres would return an error here. Let's avoid calling the driver + // and return an error early. + let kind = ErrorKind::QueryInvalidInput(format!( + "too many bind variables in prepared statement, expected maximum of {}, received {}", + i16::MAX, + params.len() + )); + Err(Error::builder(kind).build()) + } else { + Ok(()) + } + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct CockroachSearchPath<'a>(&'a str); + +impl Display for CockroachSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct PostgresSearchPath<'a>(&'a str); + +impl Display for PostgresSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("\"")?; + f.write_str(self.0)?; + f.write_str("\"")?; + + Ok(()) + } +} + +// A SetSearchPath statement (Display-impl) for connection initialization. +struct SetSearchPath<'a>(Option<&'a str>); + +impl Display for SetSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(schema) = self.0 { + f.write_str("SET search_path = \"")?; + f.write_str(schema)?; + f.write_str("\";\n")?; + } + + Ok(()) + } +} + +impl_default_TransactionCapable!(PostgreSql); + +#[async_trait] +impl Queryable for PostgreSql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Postgres::build(q)?; + + self.query_raw(sql.as_str(), ¶ms[..]).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.query_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + + for row in rows { + result.rows.push(row.get_result_row()?); + } + + Ok(result) + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.query_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + + for row in rows { + result.rows.push(row.get_result_row()?); + } + + Ok(result) + }) + .await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Postgres::build(q)?; + + self.execute_raw(sql.as_str(), ¶ms[..]).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.execute_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + Ok(changes) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.execute_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + Ok(changes) + }) + .await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("postgres.raw_cmd", cmd, &[], move || async move { + self.perform_io(self.client.0.simple_query(cmd)).await?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT version()"#; + let rows = self.query_raw(query, &[]).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn server_reset_query(&self, tx: &dyn Transaction) -> crate::Result<()> { + if self.pg_bouncer { + tx.raw_cmd("DEALLOCATE ALL").await + } else { + Ok(()) + } + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + if matches!(isolation_level, IsolationLevel::Snapshot) { + return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); + } + + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + false + } +} + +/// Sorted list of CockroachDB's reserved keywords. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_KEYWORDS: [&str; 79] = [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "concurrently", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_schema", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "lateral", + "leading", + "limit", + "localtime", + "localtimestamp", + "not", + "null", + "offset", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", +]; + +/// Sorted list of CockroachDB's reserved type function names. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ + "authorization", + "collation", + "cross", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "none", + "notnull", + "outer", + "overlaps", + "right", + "similar", +]; + +/// Returns true if a Postgres identifier is considered "safe". +/// +/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. +/// +/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers +fn is_safe_identifier(ident: &str) -> bool { + if ident.is_empty() { + return false; + } + + // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. + if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { + return false; + } + + let mut chars = ident.chars(); + + let first = chars.next().unwrap(); + + // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). + if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { + return false; + } + + for c in chars { + // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). + if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { + return false; + } + } + + true +} + +#[cfg(test)] +mod tests { + use super::*; + pub(crate) use crate::connector::postgres::url::PostgresFlavour; + use crate::connector::Queryable; + use crate::tests::test_api::postgres::CONN_STR; + use crate::tests::test_api::CRDB_CONN_STR; + use url::Url; + + #[tokio::test] + async fn test_custom_search_path_pg() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); + assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_pg_pgbouncer() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + url.query_pairs_mut().append_pair("pbbouncer", "true"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); + assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_crdb() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("àbracadabra")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("héllo")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("héll0$")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("héll_0$")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_unknown_pg() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Unknown); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_unknown_crdb() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Unknown); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[test] + fn test_safe_ident() { + // Safe + assert!(is_safe_identifier("hello")); + assert!(is_safe_identifier("_hello")); + assert!(is_safe_identifier("àbracadabra")); + assert!(is_safe_identifier("h3ll0")); + assert!(is_safe_identifier("héllo")); + assert!(is_safe_identifier("héll0$")); + assert!(is_safe_identifier("héll_0$")); + assert!(is_safe_identifier("disconnect_security_must_honor_connect_scope_one2m")); + + // Not safe + assert!(!is_safe_identifier("")); + assert!(!is_safe_identifier("Hello")); + assert!(!is_safe_identifier("hEllo")); + assert!(!is_safe_identifier("$hello")); + assert!(!is_safe_identifier("hello!")); + assert!(!is_safe_identifier("hello#")); + assert!(!is_safe_identifier("he llo")); + assert!(!is_safe_identifier(" hello")); + assert!(!is_safe_identifier("he-llo")); + assert!(!is_safe_identifier("hÉllo")); + assert!(!is_safe_identifier("1337")); + assert!(!is_safe_identifier("_HELLO")); + assert!(!is_safe_identifier("HELLO")); + assert!(!is_safe_identifier("HELLO$")); + assert!(!is_safe_identifier("ÀBRACADABRA")); + + for ident in RESERVED_KEYWORDS { + assert!(!is_safe_identifier(ident)); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert!(!is_safe_identifier(ident)); + } + } +} diff --git a/quaint/src/connector/postgres/url.rs b/quaint/src/connector/postgres/url.rs new file mode 100644 index 000000000000..f0b60d88a848 --- /dev/null +++ b/quaint/src/connector/postgres/url.rs @@ -0,0 +1,695 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use std::{ + borrow::Cow, + fmt::{Debug, Display}, + time::Duration, +}; + +use percent_encoding::percent_decode; +use url::{Host, Url}; + +use crate::error::{Error, ErrorKind}; + +#[cfg(feature = "postgresql-native")] +use tokio_postgres::config::{ChannelBinding, SslMode}; + +#[derive(Clone)] +pub(crate) struct Hidden(pub(crate) T); + +impl Debug for Hidden { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("") + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SslAcceptMode { + Strict, + AcceptInvalidCerts, +} + +#[derive(Debug, Clone)] +pub struct SslParams { + pub(crate) certificate_file: Option, + pub(crate) identity_file: Option, + pub(crate) identity_password: Hidden>, + pub(crate) ssl_accept_mode: SslAcceptMode, +} + +#[derive(Debug, Clone, Copy)] +pub enum PostgresFlavour { + Postgres, + Cockroach, + Unknown, +} + +impl PostgresFlavour { + /// Returns `true` if the postgres flavour is [`Postgres`]. + /// + /// [`Postgres`]: PostgresFlavour::Postgres + pub(crate) fn is_postgres(&self) -> bool { + matches!(self, Self::Postgres) + } + + /// Returns `true` if the postgres flavour is [`Cockroach`]. + /// + /// [`Cockroach`]: PostgresFlavour::Cockroach + pub(crate) fn is_cockroach(&self) -> bool { + matches!(self, Self::Cockroach) + } + + /// Returns `true` if the postgres flavour is [`Unknown`]. + /// + /// [`Unknown`]: PostgresFlavour::Unknown + pub(crate) fn is_unknown(&self) -> bool { + matches!(self, Self::Unknown) + } +} + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug, Clone)] +pub struct PostgresUrl { + pub(crate) url: Url, + pub(crate) query_params: PostgresUrlQueryParams, + pub(crate) flavour: PostgresFlavour, +} + +pub(crate) const DEFAULT_SCHEMA: &str = "public"; + +impl PostgresUrl { + /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { + url, + query_params, + flavour: PostgresFlavour::Unknown, + }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The database host. Taken first from the `host` query parameter, then + /// from the `host` part of the URL. For socket connections, the query + /// parameter must be used. + /// + /// If none of them are set, defaults to `localhost`. + pub fn host(&self) -> &str { + match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) { + (Some(host), _, _) => host.as_str(), + (None, Some(""), _) => "localhost", + (None, None, _) => "localhost", + (None, Some(host), Some(Host::Ipv6(_))) => { + // The `url` crate may return an IPv6 address in brackets, which must be stripped. + if host.starts_with('[') && host.ends_with(']') { + &host[1..host.len() - 1] + } else { + host + } + } + (None, Some(host), _) => host, + } + } + + /// Name of the database connected. Defaults to `postgres`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("postgres"), + None => "postgres", + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Cow { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => password, + None => self.url.password().unwrap_or("").into(), + } + } + + /// The database port, defaults to `5432`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(5432) + } + + /// The database schema, defaults to `public`. + pub fn schema(&self) -> &str { + self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) + } + + /// Whether the pgbouncer mode is enabled. + pub fn pg_bouncer(&self) -> bool { + self.query_params.pg_bouncer + } + + /// The connection timeout. + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + /// Pool check_out timeout + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout + } + + /// The socket timeout + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime + } + + /// The custom application name + pub fn application_name(&self) -> Option<&str> { + self.query_params.application_name.as_deref() + } + + pub(crate) fn options(&self) -> Option<&str> { + self.query_params.options.as_deref() + } + + /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. + /// This is used to avoid a network roundtrip at connection to set the search path. + /// + /// The different behaviours are: + /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. + /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. + /// - Unknown: Always add a network roundtrip by setting the search path through a database query. + pub fn set_flavour(&mut self, flavour: PostgresFlavour) { + self.flavour = flavour; + } + + fn parse_query_params(url: &Url) -> Result { + #[cfg(feature = "postgresql-native")] + let mut ssl_mode = SslMode::Prefer; + #[cfg(feature = "postgresql-native")] + let mut channel_binding = ChannelBinding::Prefer; + + let mut connection_limit = None; + let mut schema = None; + let mut certificate_file = None; + let mut identity_file = None; + let mut identity_password = None; + let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + let mut host = None; + let mut application_name = None; + let mut socket_timeout = None; + let mut connect_timeout = Some(Duration::from_secs(5)); + let mut pool_timeout = Some(Duration::from_secs(10)); + let mut pg_bouncer = false; + let mut statement_cache_size = 100; + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); + let mut options = None; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "pgbouncer" => { + pg_bouncer = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + #[cfg(feature = "postgresql-native")] + "sslmode" => { + match v.as_ref() { + "disable" => ssl_mode = SslMode::Disable, + "prefer" => ssl_mode = SslMode::Prefer, + "require" => ssl_mode = SslMode::Require, + _ => { + tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); + } + }; + } + "sslcert" => { + certificate_file = Some(v.to_string()); + } + "sslidentity" => { + identity_file = Some(v.to_string()); + } + "sslpassword" => { + identity_password = Some(v.to_string()); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslaccept" => { + match v.as_ref() { + "strict" => { + ssl_accept_mode = SslAcceptMode::Strict; + } + "accept_invalid_certs" => { + ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + } + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `strict`", + mode = &*v + ); + + ssl_accept_mode = SslAcceptMode::Strict; + } + }; + } + "schema" => { + schema = Some(v.to_string()); + } + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + connection_limit = Some(as_int); + } + "host" => { + host = Some(v.to_string()); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "connect_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + connect_timeout = None; + } else { + connect_timeout = Some(Duration::from_secs(as_int)); + } + } + "pool_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + pool_timeout = None; + } else { + pool_timeout = Some(Duration::from_secs(as_int)); + } + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "application_name" => { + application_name = Some(v.to_string()); + } + #[cfg(feature = "postgresql-native")] + "channel_binding" => { + match v.as_ref() { + "disable" => channel_binding = ChannelBinding::Disable, + "prefer" => channel_binding = ChannelBinding::Prefer, + "require" => channel_binding = ChannelBinding::Require, + _ => { + tracing::debug!( + message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", + channel_binding = &*v + ); + } + }; + } + "options" => { + options = Some(v.to_string()); + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + Ok(PostgresUrlQueryParams { + ssl_params: SslParams { + certificate_file, + identity_file, + ssl_accept_mode, + identity_password: Hidden(identity_password), + }, + connection_limit, + schema, + host, + connect_timeout, + pool_timeout, + socket_timeout, + pg_bouncer, + statement_cache_size, + max_connection_lifetime, + max_idle_connection_lifetime, + application_name, + options, + #[cfg(feature = "postgresql-native")] + channel_binding, + #[cfg(feature = "postgresql-native")] + ssl_mode, + }) + } + + pub(crate) fn ssl_params(&self) -> &SslParams { + &self.query_params.ssl_params + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } + + pub fn flavour(&self) -> PostgresFlavour { + self.flavour + } +} + +#[derive(Debug, Clone)] +pub(crate) struct PostgresUrlQueryParams { + pub(crate) ssl_params: SslParams, + pub(crate) connection_limit: Option, + pub(crate) schema: Option, + pub(crate) pg_bouncer: bool, + pub(crate) host: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) statement_cache_size: usize, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) application_name: Option, + pub(crate) options: Option, + + #[cfg(feature = "postgresql-native")] + pub(crate) channel_binding: ChannelBinding, + + #[cfg(feature = "postgresql-native")] + pub(crate) ssl_mode: SslMode, +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct CockroachSearchPath<'a>(&'a str); + +impl Display for CockroachSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct PostgresSearchPath<'a>(&'a str); + +impl Display for PostgresSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("\"")?; + f.write_str(self.0)?; + f.write_str("\"")?; + + Ok(()) + } +} + +// A SetSearchPath statement (Display-impl) for connection initialization. +struct SetSearchPath<'a>(Option<&'a str>); + +impl Display for SetSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(schema) = self.0 { + f.write_str("SET search_path = \"")?; + f.write_str(schema)?; + f.write_str("\";\n")?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::Value; + pub(crate) use crate::connector::postgres::url::PostgresFlavour; + use crate::tests::test_api::postgres::CONN_STR; + use crate::{connector::Queryable, error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/psql.sock", url.host()); + } + + #[test] + fn should_parse_escaped_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/postgresql", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[test] + fn should_have_application_name() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); + assert_eq!(Some("test"), url.application_name()); + } + + #[test] + fn should_have_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Require, url.channel_binding()); + } + + #[test] + fn should_have_default_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + } + + #[test] + fn should_not_enable_caching_with_pgbouncer() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); + assert_eq!(0, url.cache().capacity()); + } + + #[test] + fn should_parse_default_host() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("localhost", url.host()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); + assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_handle_options_field() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) + .unwrap(); + + assert_eq!("--cluster=my_cluster", url.options().unwrap()); + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_path("/this_does_not_exist"); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("3D000"), e.original_code()); + assert_eq!( + Some("database \"this_does_not_exist\" does not exist"), + e.original_message() + ); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), + }, + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } + + #[tokio::test] + async fn should_map_tls_errors() { + let mut url = Url::parse(&CONN_STR).expect("parsing url"); + url.set_query(Some("sslmode=require&sslaccept=strict")); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::TlsError { .. } => (), + other => panic!("{:#?}", other), + }, + } + } + + #[tokio::test] + async fn should_map_incorrect_parameters_error() { + let url = Url::parse(&CONN_STR).unwrap(); + let conn = Quaint::new(url.as_str()).await.unwrap(); + + let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::IncorrectNumberOfParameters { expected, actual } => { + assert_eq!(1, *expected); + assert_eq!(2, *actual); + } + other => panic!("{:#?}", other), + }, + } + } + + #[test] + fn search_path_pgbouncer_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + url.query_pairs_mut().append_pair("pgbouncer", "true"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // PGBouncer does not support the `search_path` connection parameter. + // When `pgbouncer=true`, config.search_path should be None, + // And the `search_path` should be set via a db query after connection. + assert_eq!(config.get_search_path(), None); + } + + #[test] + fn search_path_pg_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // Postgres supports setting the search_path via a connection parameter. + assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); + } + + #[test] + fn search_path_crdb_safe_ident_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB supports setting the search_path via a connection parameter if the identifier is safe. + assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); + } + + #[test] + fn search_path_crdb_unsafe_ident_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "HeLLo"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. + assert_eq!(config.get_search_path(), None); + } +} diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 3a1ef72b4883..c59c947b8dc1 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -1,353 +1,11 @@ -mod conversion; -mod error; +//! Wasm-compatible definitions for the SQLite connector. +//! This module is only available with the `sqlite` feature. +pub(crate) mod error; +mod ffi; +pub(crate) mod params; pub use error::SqliteError; +pub use params::*; -pub use rusqlite::{params_from_iter, version as sqlite_version}; - -use super::IsolationLevel; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use std::{convert::TryFrom, path::Path, time::Duration}; -use tokio::sync::Mutex; - -pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main"; - -/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use rusqlite; - -/// A connector interface for the SQLite database -pub struct Sqlite { - pub(crate) client: Mutex, -} - -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug)] -pub struct SqliteParams { - pub connection_limit: Option, - /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can - /// only be done with UTF-8 paths. - pub file_path: String, - pub db_name: String, - pub socket_timeout: Option, - pub max_connection_lifetime: Option, - pub max_idle_connection_lifetime: Option, -} - -impl TryFrom<&str> for SqliteParams { - type Error = Error; - - fn try_from(path: &str) -> crate::Result { - let path = if path.starts_with("file:") { - path.trim_start_matches("file:") - } else { - path.trim_start_matches("sqlite:") - }; - - let path_parts: Vec<&str> = path.split('?').collect(); - let path_str = path_parts[0]; - let path = Path::new(path_str); - - if path.is_dir() { - Err(Error::builder(ErrorKind::DatabaseUrlIsInvalid(path.to_str().unwrap().to_string())).build()) - } else { - let mut connection_limit = None; - let mut socket_timeout = None; - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = None; - - if path_parts.len() > 1 { - let params = path_parts.last().unwrap().split('&').map(|kv| { - let splitted: Vec<&str> = kv.split('=').collect(); - (splitted[0], splitted[1]) - }); - - for (k, v) in params { - match k { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - socket_timeout = Some(Duration::from_secs(as_int)); - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = k); - } - }; - } - } - - Ok(Self { - connection_limit, - file_path: path_str.to_owned(), - db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), - socket_timeout, - max_connection_lifetime, - max_idle_connection_lifetime, - }) - } - } -} - -impl TryFrom<&str> for Sqlite { - type Error = Error; - - fn try_from(path: &str) -> crate::Result { - let params = SqliteParams::try_from(path)?; - let file_path = params.file_path; - - let conn = rusqlite::Connection::open(file_path.as_str())?; - - if let Some(timeout) = params.socket_timeout { - conn.busy_timeout(timeout)?; - }; - - let client = Mutex::new(conn); - - Ok(Sqlite { client }) - } -} - -impl Sqlite { - pub fn new(file_path: &str) -> crate::Result { - Self::try_from(file_path) - } - - /// Open a new SQLite database in memory. - pub fn new_in_memory() -> crate::Result { - let client = rusqlite::Connection::open_in_memory()?; - - Ok(Sqlite { - client: Mutex::new(client), - }) - } - - /// The underlying rusqlite::Connection. Only available with the `expose-drivers` Cargo - /// feature. This is a lower level API when you need to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn connection(&self) -> &Mutex { - &self.client - } -} - -impl_default_TransactionCapable!(Sqlite); - -#[async_trait] -impl Queryable for Sqlite { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Sqlite::build(q)?; - self.query_raw(&sql, ¶ms).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { - let client = self.client.lock().await; - - let mut stmt = client.prepare_cached(sql)?; - - let mut rows = stmt.query(params_from_iter(params.iter()))?; - let mut result = ResultSet::new(rows.to_column_names(), Vec::new()); - - while let Some(row) = rows.next()? { - result.rows.push(row.get_result_row()?); - } - - result.set_last_insert_id(u64::try_from(client.last_insert_rowid()).unwrap_or(0)); - - Ok(result) - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Sqlite::build(q)?; - self.execute_raw(&sql, ¶ms).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { - let client = self.client.lock().await; - let mut stmt = client.prepare_cached(sql)?; - let res = u64::try_from(stmt.execute(params_from_iter(params.iter()))?)?; - - Ok(res) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("sqlite.raw_cmd", cmd, &[], move || async move { - let client = self.client.lock().await; - client.execute_batch(cmd)?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - Ok(Some(rusqlite::version().into())) - } - - fn is_healthy(&self) -> bool { - true - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - // SQLite is always "serializable", other modes involve pragmas - // and shared cache mode, which is out of scope for now and should be implemented - // as part of a separate effort. - if !matches!(isolation_level, IsolationLevel::Serializable) { - let kind = ErrorKind::invalid_isolation_level(&isolation_level); - return Err(Error::builder(kind).build()); - } - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - false - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - ast::*, - connector::Queryable, - error::{ErrorKind, Name}, - }; - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { - let path = "file:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { - let path = "sqlite:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { - let path = "dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[tokio::test] - async fn unknown_table_should_give_a_good_error() { - let conn = Sqlite::try_from("file:db/test.db").unwrap(); - let select = Select::from_table("not_there"); - - let err = conn.select(select).await.unwrap_err(); - - match err.kind() { - ErrorKind::TableDoesNotExist { table } => { - assert_eq!(&Name::available("not_there"), table); - } - e => panic!("Expected error TableDoesNotExist, got {:?}", e), - } - } - - #[tokio::test] - async fn in_memory_sqlite_works() { - let conn = Sqlite::new_in_memory().unwrap(); - - conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, txt TEXT NOT NULL);") - .await - .unwrap(); - - let insert = Insert::single_into("test").value("txt", "henlo"); - conn.insert(insert.into()).await.unwrap(); - - let select = Select::from_table("test").value(asterisk()); - let result = conn.select(select.clone()).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("id").unwrap(), &Value::int32(1)); - assert_eq!(result.get("txt").unwrap(), &Value::text("henlo")); - - // Check that we do get a separate, new database. - let other_conn = Sqlite::new_in_memory().unwrap(); - - let err = other_conn.select(select).await.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::TableDoesNotExist { .. })); - } - - #[tokio::test] - async fn quoting_in_returning_in_sqlite_works() { - let conn = Sqlite::new_in_memory().unwrap(); - - conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, `txt space` TEXT NOT NULL);") - .await - .unwrap(); - - let insert = Insert::single_into("test").value("txt space", "henlo"); - conn.insert(insert.into()).await.unwrap(); - - let select = Select::from_table("test").value(asterisk()); - let result = conn.select(select.clone()).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("id").unwrap(), &Value::int32(1)); - assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); - - let insert = Insert::single_into("test").value("txt space", "henlo"); - let insert: Insert = Insert::from(insert).returning(["txt space"]); - - let result = conn.insert(insert).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); - } -} +#[cfg(feature = "sqlite-native")] +pub(crate) mod native; diff --git a/quaint/src/connector/sqlite/error.rs b/quaint/src/connector/sqlite/error.rs index c10b335cb3c0..2c6ff11350fd 100644 --- a/quaint/src/connector/sqlite/error.rs +++ b/quaint/src/connector/sqlite/error.rs @@ -1,8 +1,4 @@ -use std::fmt; - use crate::error::*; -use rusqlite::ffi; -use rusqlite::types::FromSqlError; #[derive(Debug)] pub struct SqliteError { @@ -10,14 +6,10 @@ pub struct SqliteError { pub message: Option, } -impl fmt::Display for SqliteError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "Error code {}: {}", - self.extended_code, - ffi::code_to_str(self.extended_code) - ) +#[cfg(not(feature = "sqlite-native"))] +impl std::fmt::Display for SqliteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Error code {}", self.extended_code) } } @@ -37,7 +29,7 @@ impl From for Error { fn from(error: SqliteError) -> Self { match error { SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_UNIQUE | ffi::SQLITE_CONSTRAINT_PRIMARYKEY, + extended_code: super::ffi::SQLITE_CONSTRAINT_UNIQUE | super::ffi::SQLITE_CONSTRAINT_PRIMARYKEY, message: Some(description), } => { let constraint = description @@ -58,7 +50,7 @@ impl From for Error { } SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_NOTNULL, + extended_code: super::ffi::SQLITE_CONSTRAINT_NOTNULL, message: Some(description), } => { let constraint = description @@ -79,7 +71,7 @@ impl From for Error { } SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_FOREIGNKEY | ffi::SQLITE_CONSTRAINT_TRIGGER, + extended_code: super::ffi::SQLITE_CONSTRAINT_FOREIGNKEY | super::ffi::SQLITE_CONSTRAINT_TRIGGER, message: Some(description), } => { let mut builder = Error::builder(ErrorKind::ForeignKeyConstraintViolation { @@ -92,7 +84,7 @@ impl From for Error { builder.build() } - SqliteError { extended_code, message } if error.primary_code() == ffi::SQLITE_BUSY => { + SqliteError { extended_code, message } if error.primary_code() == super::ffi::SQLITE_BUSY => { let mut builder = Error::builder(ErrorKind::SocketTimeout); builder.set_original_code(format!("{extended_code}")); @@ -152,55 +144,3 @@ impl From for Error { } } } - -impl From for Error { - fn from(e: rusqlite::Error) -> Error { - match e { - rusqlite::Error::ToSqlConversionFailure(error) => match error.downcast::() { - Ok(error) => *error, - Err(error) => { - let mut builder = Error::builder(ErrorKind::QueryError(error)); - - builder.set_original_message("Could not interpret parameters in an SQLite query."); - - builder.build() - } - }, - rusqlite::Error::InvalidQuery => { - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - - builder.set_original_message( - "Could not interpret the query or its parameters. Check the syntax and parameter types.", - ); - - builder.build() - } - rusqlite::Error::ExecuteReturnedResults => { - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - builder.set_original_message("Execute returned results, which is not allowed in SQLite."); - - builder.build() - } - - rusqlite::Error::QueryReturnedNoRows => Error::builder(ErrorKind::NotFound).build(), - - rusqlite::Error::SqliteFailure(ffi::Error { code: _, extended_code }, message) => { - SqliteError::new(extended_code, message).into() - } - - rusqlite::Error::SqlInputError { - error: ffi::Error { extended_code, .. }, - msg, - .. - } => SqliteError::new(extended_code, Some(msg)).into(), - - e => Error::builder(ErrorKind::QueryError(e.into())).build(), - } - } -} - -impl From for Error { - fn from(e: FromSqlError) -> Error { - Error::builder(ErrorKind::ColumnReadFailure(e.into())).build() - } -} diff --git a/quaint/src/connector/sqlite/ffi.rs b/quaint/src/connector/sqlite/ffi.rs new file mode 100644 index 000000000000..c510a459be81 --- /dev/null +++ b/quaint/src/connector/sqlite/ffi.rs @@ -0,0 +1,8 @@ +//! Here, we export only the constants we need to avoid pulling in `rusqlite::ffi::*`, in the sibling `error.rs` file, +//! which would break Wasm compilation. +pub const SQLITE_BUSY: i32 = 5; +pub const SQLITE_CONSTRAINT_FOREIGNKEY: i32 = 787; +pub const SQLITE_CONSTRAINT_NOTNULL: i32 = 1299; +pub const SQLITE_CONSTRAINT_PRIMARYKEY: i32 = 1555; +pub const SQLITE_CONSTRAINT_TRIGGER: i32 = 1811; +pub const SQLITE_CONSTRAINT_UNIQUE: i32 = 2067; diff --git a/quaint/src/connector/sqlite/conversion.rs b/quaint/src/connector/sqlite/native/conversion.rs similarity index 100% rename from quaint/src/connector/sqlite/conversion.rs rename to quaint/src/connector/sqlite/native/conversion.rs diff --git a/quaint/src/connector/sqlite/native/error.rs b/quaint/src/connector/sqlite/native/error.rs new file mode 100644 index 000000000000..51b2417ed821 --- /dev/null +++ b/quaint/src/connector/sqlite/native/error.rs @@ -0,0 +1,66 @@ +use crate::connector::sqlite::error::SqliteError; + +use crate::error::*; + +impl std::fmt::Display for SqliteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Error code {}: {}", + self.extended_code, + rusqlite::ffi::code_to_str(self.extended_code) + ) + } +} + +impl From for Error { + fn from(e: rusqlite::Error) -> Error { + match e { + rusqlite::Error::ToSqlConversionFailure(error) => match error.downcast::() { + Ok(error) => *error, + Err(error) => { + let mut builder = Error::builder(ErrorKind::QueryError(error)); + + builder.set_original_message("Could not interpret parameters in an SQLite query."); + + builder.build() + } + }, + rusqlite::Error::InvalidQuery => { + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); + + builder.set_original_message( + "Could not interpret the query or its parameters. Check the syntax and parameter types.", + ); + + builder.build() + } + rusqlite::Error::ExecuteReturnedResults => { + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); + builder.set_original_message("Execute returned results, which is not allowed in SQLite."); + + builder.build() + } + + rusqlite::Error::QueryReturnedNoRows => Error::builder(ErrorKind::NotFound).build(), + + rusqlite::Error::SqliteFailure(rusqlite::ffi::Error { code: _, extended_code }, message) => { + SqliteError::new(extended_code, message).into() + } + + rusqlite::Error::SqlInputError { + error: rusqlite::ffi::Error { extended_code, .. }, + msg, + .. + } => SqliteError::new(extended_code, Some(msg)).into(), + + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} + +impl From for Error { + fn from(e: rusqlite::types::FromSqlError) -> Error { + Error::builder(ErrorKind::ColumnReadFailure(e.into())).build() + } +} diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs new file mode 100644 index 000000000000..3bf0c46a7db5 --- /dev/null +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -0,0 +1,234 @@ +//! Definitions for the SQLite connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `sqlite-native` feature. +mod conversion; +mod error; + +use crate::connector::sqlite::params::SqliteParams; +use crate::connector::IsolationLevel; + +pub use rusqlite::{params_from_iter, version as sqlite_version}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use std::convert::TryFrom; +use tokio::sync::Mutex; + +/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use rusqlite; + +/// A connector interface for the SQLite database +pub struct Sqlite { + pub(crate) client: Mutex, +} + +impl TryFrom<&str> for Sqlite { + type Error = Error; + + fn try_from(path: &str) -> crate::Result { + let params = SqliteParams::try_from(path)?; + let file_path = params.file_path; + + let conn = rusqlite::Connection::open(file_path.as_str())?; + + if let Some(timeout) = params.socket_timeout { + conn.busy_timeout(timeout)?; + }; + + let client = Mutex::new(conn); + + Ok(Sqlite { client }) + } +} + +impl Sqlite { + pub fn new(file_path: &str) -> crate::Result { + Self::try_from(file_path) + } + + /// Open a new SQLite database in memory. + pub fn new_in_memory() -> crate::Result { + let client = rusqlite::Connection::open_in_memory()?; + + Ok(Sqlite { + client: Mutex::new(client), + }) + } + + /// The underlying rusqlite::Connection. Only available with the `expose-drivers` Cargo + /// feature. This is a lower level API when you need to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn connection(&self) -> &Mutex { + &self.client + } +} + +impl_default_TransactionCapable!(Sqlite); + +#[async_trait] +impl Queryable for Sqlite { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Sqlite::build(q)?; + self.query_raw(&sql, ¶ms).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("sqlite.query_raw", sql, params, move || async move { + let client = self.client.lock().await; + + let mut stmt = client.prepare_cached(sql)?; + + let mut rows = stmt.query(params_from_iter(params.iter()))?; + let mut result = ResultSet::new(rows.to_column_names(), Vec::new()); + + while let Some(row) = rows.next()? { + result.rows.push(row.get_result_row()?); + } + + result.set_last_insert_id(u64::try_from(client.last_insert_rowid()).unwrap_or(0)); + + Ok(result) + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Sqlite::build(q)?; + self.execute_raw(&sql, ¶ms).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("sqlite.query_raw", sql, params, move || async move { + let client = self.client.lock().await; + let mut stmt = client.prepare_cached(sql)?; + let res = u64::try_from(stmt.execute(params_from_iter(params.iter()))?)?; + + Ok(res) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("sqlite.raw_cmd", cmd, &[], move || async move { + let client = self.client.lock().await; + client.execute_batch(cmd)?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + Ok(Some(rusqlite::version().into())) + } + + fn is_healthy(&self) -> bool { + true + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + // SQLite is always "serializable", other modes involve pragmas + // and shared cache mode, which is out of scope for now and should be implemented + // as part of a separate effort. + if !matches!(isolation_level, IsolationLevel::Serializable) { + let kind = ErrorKind::invalid_isolation_level(&isolation_level); + return Err(Error::builder(kind).build()); + } + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + ast::*, + connector::Queryable, + error::{ErrorKind, Name}, + }; + + #[tokio::test] + async fn unknown_table_should_give_a_good_error() { + let conn = Sqlite::try_from("file:db/test.db").unwrap(); + let select = Select::from_table("not_there"); + + let err = conn.select(select).await.unwrap_err(); + + match err.kind() { + ErrorKind::TableDoesNotExist { table } => { + assert_eq!(&Name::available("not_there"), table); + } + e => panic!("Expected error TableDoesNotExist, got {:?}", e), + } + } + + #[tokio::test] + async fn in_memory_sqlite_works() { + let conn = Sqlite::new_in_memory().unwrap(); + + conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, txt TEXT NOT NULL);") + .await + .unwrap(); + + let insert = Insert::single_into("test").value("txt", "henlo"); + conn.insert(insert.into()).await.unwrap(); + + let select = Select::from_table("test").value(asterisk()); + let result = conn.select(select.clone()).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("id").unwrap(), &Value::int32(1)); + assert_eq!(result.get("txt").unwrap(), &Value::text("henlo")); + + // Check that we do get a separate, new database. + let other_conn = Sqlite::new_in_memory().unwrap(); + + let err = other_conn.select(select).await.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::TableDoesNotExist { .. })); + } + + #[tokio::test] + async fn quoting_in_returning_in_sqlite_works() { + let conn = Sqlite::new_in_memory().unwrap(); + + conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, `txt space` TEXT NOT NULL);") + .await + .unwrap(); + + let insert = Insert::single_into("test").value("txt space", "henlo"); + conn.insert(insert.into()).await.unwrap(); + + let select = Select::from_table("test").value(asterisk()); + let result = conn.select(select.clone()).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("id").unwrap(), &Value::int32(1)); + assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); + + let insert = Insert::single_into("test").value("txt space", "henlo"); + let insert: Insert = Insert::from(insert).returning(["txt space"]); + + let result = conn.insert(insert).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); + } +} diff --git a/quaint/src/connector/sqlite/params.rs b/quaint/src/connector/sqlite/params.rs new file mode 100644 index 000000000000..f024aa97a694 --- /dev/null +++ b/quaint/src/connector/sqlite/params.rs @@ -0,0 +1,131 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use crate::error::{Error, ErrorKind}; +use std::{convert::TryFrom, path::Path, time::Duration}; + +pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main"; + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug)] +pub struct SqliteParams { + pub connection_limit: Option, + /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can + /// only be done with UTF-8 paths. + pub file_path: String, + pub db_name: String, + pub socket_timeout: Option, + pub max_connection_lifetime: Option, + pub max_idle_connection_lifetime: Option, +} + +impl TryFrom<&str> for SqliteParams { + type Error = Error; + + fn try_from(path: &str) -> crate::Result { + let path = if path.starts_with("file:") { + path.trim_start_matches("file:") + } else { + path.trim_start_matches("sqlite:") + }; + + let path_parts: Vec<&str> = path.split('?').collect(); + let path_str = path_parts[0]; + let path = Path::new(path_str); + + if path.is_dir() { + Err(Error::builder(ErrorKind::DatabaseUrlIsInvalid(path.to_str().unwrap().to_string())).build()) + } else { + let mut connection_limit = None; + let mut socket_timeout = None; + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = None; + + if path_parts.len() > 1 { + let params = path_parts.last().unwrap().split('&').map(|kv| { + let splitted: Vec<&str> = kv.split('=').collect(); + (splitted[0], splitted[1]) + }); + + for (k, v) in params { + match k { + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connection_limit = Some(as_int); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + socket_timeout = Some(Duration::from_secs(as_int)); + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = k); + } + }; + } + } + + Ok(Self { + connection_limit, + file_path: path_str.to_owned(), + db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), + socket_timeout, + max_connection_lifetime, + max_idle_connection_lifetime, + }) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { + let path = "file:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { + let path = "sqlite:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { + let path = "dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } +} diff --git a/quaint/src/error.rs b/quaint/src/error.rs index 705bb6b37ee0..a77513876726 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -282,7 +282,7 @@ pub enum ErrorKind { } impl ErrorKind { - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] pub(crate) fn value_out_of_range(msg: impl Into) -> Self { Self::ValueOutOfRange { message: msg.into() } } diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index c0aa8c93b75d..73441b7609ba 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -1,8 +1,8 @@ -#[cfg(feature = "mssql")] +#[cfg(feature = "mssql-native")] use crate::connector::MssqlUrl; -#[cfg(feature = "mysql")] +#[cfg(feature = "mysql-native")] use crate::connector::MysqlUrl; -#[cfg(feature = "postgresql")] +#[cfg(feature = "postgresql-native")] use crate::connector::PostgresUrl; use crate::{ ast, @@ -97,7 +97,7 @@ impl Manager for QuaintManager { async fn connect(&self) -> crate::Result { let conn = match self { - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] QuaintManager::Sqlite { url, .. } => { use crate::connector::Sqlite; @@ -106,19 +106,19 @@ impl Manager for QuaintManager { Ok(Box::new(conn) as Self::Connection) } - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] QuaintManager::Mysql { url } => { use crate::connector::Mysql; Ok(Box::new(Mysql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-native")] QuaintManager::Postgres { url } => { use crate::connector::PostgreSql; Ok(Box::new(PostgreSql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-native")] QuaintManager::Mssql { url } => { use crate::connector::Mssql; Ok(Box::new(Mssql::new(url.clone()).await?) as Self::Connection) @@ -146,7 +146,7 @@ mod tests { use crate::pooled::Quaint; #[tokio::test] - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] async fn mysql_default_connection_limit() { let conn_string = std::env::var("TEST_MYSQL").expect("TEST_MYSQL connection string not set."); @@ -156,7 +156,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] async fn mysql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -169,7 +169,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-native")] async fn psql_default_connection_limit() { let conn_string = std::env::var("TEST_PSQL").expect("TEST_PSQL connection string not set."); @@ -179,7 +179,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-native")] async fn psql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -192,7 +192,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-native")] async fn mssql_default_connection_limit() { let conn_string = std::env::var("TEST_MSSQL").expect("TEST_MSSQL connection string not set."); @@ -202,7 +202,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-native")] async fn mssql_custom_connection_limit() { let conn_string = format!( "{};connectionLimit=10", @@ -215,7 +215,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] async fn test_default_connection_limit() { let conn_string = "file:db/test.db".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); @@ -224,7 +224,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] async fn test_custom_connection_limit() { let conn_string = "file:db/test.db?connection_limit=10".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 82042f58010b..1a4dbdf52a61 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -1,7 +1,5 @@ //! A single connection abstraction to a SQL database. -#[cfg(feature = "sqlite")] -use crate::connector::DEFAULT_SQLITE_SCHEMA_NAME; use crate::{ ast, connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, @@ -9,7 +7,7 @@ use crate::{ use async_trait::async_trait; use std::{fmt, sync::Arc}; -#[cfg(feature = "sqlite")] +#[cfg(feature = "sqlite-native")] use std::convert::TryFrom; /// The main entry point and an abstraction over a database connection. @@ -127,30 +125,31 @@ impl Quaint { /// - `isolationLevel` the transaction isolation level. Possible values: /// `READ UNCOMMITTED`, `READ COMMITTED`, `REPEATABLE READ`, `SNAPSHOT`, /// `SERIALIZABLE`. + #[cfg_attr(target_arch = "wasm32", allow(unused_variables))] #[allow(unreachable_code)] pub async fn new(url_str: &str) -> crate::Result { let inner = match url_str { - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] s if s.starts_with("file") => { let params = connector::SqliteParams::try_from(s)?; let sqlite = connector::Sqlite::new(¶ms.file_path)?; Arc::new(sqlite) as Arc } - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-native")] s if s.starts_with("mysql") => { let url = connector::MysqlUrl::new(url::Url::parse(s)?)?; let mysql = connector::Mysql::new(url).await?; Arc::new(mysql) as Arc } - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-native")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { let url = connector::PostgresUrl::new(url::Url::parse(s)?)?; let psql = connector::PostgreSql::new(url).await?; Arc::new(psql) as Arc } - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-native")] s if s.starts_with("jdbc:sqlserver") | s.starts_with("sqlserver") => { let url = connector::MssqlUrl::new(s)?; let psql = connector::Mssql::new(url).await?; @@ -166,9 +165,11 @@ impl Quaint { Ok(Self { inner, connection_info }) } - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-native")] /// Open a new SQLite database in memory. pub fn new_in_memory() -> crate::Result { + use crate::connector::DEFAULT_SQLITE_SCHEMA_NAME; + Ok(Quaint { inner: Arc::new(connector::Sqlite::new_in_memory()?), connection_info: Arc::new(ConnectionInfo::InMemorySqlite { diff --git a/query-engine/connectors/query-connector/Cargo.toml b/query-engine/connectors/query-connector/Cargo.toml index d16771aa3daf..788b8ca65576 100644 --- a/query-engine/connectors/query-connector/Cargo.toml +++ b/query-engine/connectors/query-connector/Cargo.toml @@ -14,6 +14,6 @@ prisma-value = {path = "../../../libs/prisma-value"} serde.workspace = true serde_json.workspace = true thiserror = "1.0" -user-facing-errors = {path = "../../../libs/user-facing-errors"} +user-facing-errors = {path = "../../../libs/user-facing-errors", features = ["sql"]} uuid = "1" indexmap = "1.7"