From c54730ba0844860a626c61d52c4d7e7642e7f320 Mon Sep 17 00:00:00 2001 From: Sergey Tatarintsev Date: Thu, 8 Feb 2024 15:01:19 +0100 Subject: [PATCH] qe: Remove more per-connector code after splitting --- quaint/src/ast/delete.rs | 2 +- quaint/src/ast/function.rs | 18 +++++++++++++----- quaint/src/ast/function/json_extract.rs | 8 ++++---- quaint/src/ast/insert.rs | 2 +- quaint/src/ast/update.rs | 2 +- quaint/src/visitor.rs | 8 ++++---- quaint/src/visitor/mysql.rs | 1 - quaint/src/visitor/postgres.rs | 1 - .../src/database/operations/read.rs | 6 ++++++ .../src/database/operations/write.rs | 18 +++++++++++++++--- .../sql-query-connector/src/filter/visitor.rs | 19 +++++++++++++++++++ .../sql-query-connector/src/ordering.rs | 3 +++ .../src/query_builder/mod.rs | 1 + query-engine/request-handlers/Cargo.toml | 5 ++++- 14 files changed, 72 insertions(+), 22 deletions(-) diff --git a/quaint/src/ast/delete.rs b/quaint/src/ast/delete.rs index e0e7316e8e99..cbc03dc02b2c 100644 --- a/quaint/src/ast/delete.rs +++ b/quaint/src/ast/delete.rs @@ -91,7 +91,7 @@ impl<'a> Delete<'a> { /// assert_eq!("DELETE FROM `users` RETURNING \"id\"", sql); /// # Ok(()) /// # } - #[cfg(any(feature = "postgresql", feature = "mssql", feature = "sqlite"))] + #[cfg(any(feature = "postgresql", feature = "mssql", feature = "sqlite", feature = "mysql"))] pub fn returning(mut self, columns: I) -> Self where K: Into>, diff --git a/quaint/src/ast/function.rs b/quaint/src/ast/function.rs index 246ea762b34e..85c5121cfbeb 100644 --- a/quaint/src/ast/function.rs +++ b/quaint/src/ast/function.rs @@ -3,7 +3,9 @@ mod average; mod coalesce; mod concat; mod count; +#[cfg(any(feature = "postgresql", feature = "mysql"))] mod json_array_agg; +#[cfg(any(feature = "postgresql", feature = "mysql"))] mod json_build_obj; #[cfg(any(feature = "postgresql", feature = "mysql"))] mod json_extract; @@ -30,7 +32,9 @@ pub use average::*; pub use coalesce::*; pub use concat::*; pub use count::*; +#[cfg(any(feature = "postgresql", feature = "mysql"))] pub use json_array_agg::*; +#[cfg(any(feature = "postgresql", feature = "mysql"))] pub use json_build_obj::*; #[cfg(any(feature = "postgresql", feature = "mysql"))] pub use json_extract::*; @@ -102,9 +106,9 @@ pub(crate) enum FunctionType<'a> { JsonExtractFirstArrayElem(JsonExtractFirstArrayElem<'a>), #[cfg(any(feature = "postgresql", feature = "mysql"))] JsonUnquote(JsonUnquote<'a>), - #[cfg(feature = "postgresql")] + #[cfg(any(feature = "postgresql", feature = "mysql"))] JsonArrayAgg(JsonArrayAgg<'a>), - #[cfg(feature = "postgresql")] + #[cfg(any(feature = "postgresql", feature = "mysql"))] JsonBuildObject(JsonBuildObject<'a>), #[cfg(any(feature = "postgresql", feature = "mysql"))] TextSearch(TextSearch<'a>), @@ -151,6 +155,12 @@ function!(TextSearch); #[cfg(any(feature = "postgresql", feature = "mysql"))] function!(TextSearchRelevance); +#[cfg(any(feature = "postgresql", feature = "mysql"))] +function!(JsonArrayAgg); + +#[cfg(any(feature = "postgresql", feature = "mysql"))] +function!(JsonBuildObject); + function!( RowNumber, Count, @@ -162,7 +172,5 @@ function!( Minimum, Maximum, Coalesce, - Concat, - JsonArrayAgg, - JsonBuildObject + Concat ); diff --git a/quaint/src/ast/function/json_extract.rs b/quaint/src/ast/function/json_extract.rs index f45295026c74..c6868cd03601 100644 --- a/quaint/src/ast/function/json_extract.rs +++ b/quaint/src/ast/function/json_extract.rs @@ -11,14 +11,14 @@ pub struct JsonExtract<'a> { #[derive(Debug, Clone, PartialEq, Eq)] pub enum JsonPath<'a> { - #[cfg(feature = "mysql")] + // #[cfg(feature = "mysql")] String(Cow<'a, str>), - #[cfg(feature = "postgresql")] + // #[cfg(feature = "postgresql")] Array(Vec>), } impl<'a> JsonPath<'a> { - #[cfg(feature = "mysql")] + // #[cfg(feature = "mysql")] pub fn string(string: S) -> JsonPath<'a> where S: Into>, @@ -26,7 +26,7 @@ impl<'a> JsonPath<'a> { JsonPath::String(string.into()) } - #[cfg(feature = "postgresql")] + // #[cfg(feature = "postgresql")] pub fn array(array: A) -> JsonPath<'a> where V: Into>, diff --git a/quaint/src/ast/insert.rs b/quaint/src/ast/insert.rs index cd38fff87043..7940112328c7 100644 --- a/quaint/src/ast/insert.rs +++ b/quaint/src/ast/insert.rs @@ -255,7 +255,7 @@ impl<'a> Insert<'a> { /// # Ok(()) /// # } /// ``` - #[cfg(any(feature = "postgresql", feature = "mssql", feature = "sqlite"))] + #[cfg(any(feature = "postgresql", feature = "mssql", feature = "sqlite", feature = "mysql"))] pub fn returning(mut self, columns: I) -> Self where K: Into>, diff --git a/quaint/src/ast/update.rs b/quaint/src/ast/update.rs index 751655bd82e1..a690f070ea38 100644 --- a/quaint/src/ast/update.rs +++ b/quaint/src/ast/update.rs @@ -149,7 +149,7 @@ impl<'a> Update<'a> { /// # Ok(()) /// # } /// ``` - #[cfg(any(feature = "postgresql", feature = "sqlite"))] + #[cfg(any(feature = "postgresql", feature = "sqlite", feature = "mysql"))] pub fn returning(mut self, columns: I) -> Self where K: Into>, diff --git a/quaint/src/visitor.rs b/quaint/src/visitor.rs index 58baa09a791f..b5dd0ca6090b 100644 --- a/quaint/src/visitor.rs +++ b/quaint/src/visitor.rs @@ -139,10 +139,10 @@ pub trait Visitor<'a> { #[cfg(any(feature = "postgresql", feature = "mysql"))] fn visit_json_unquote(&mut self, json_unquote: JsonUnquote<'a>) -> Result; - #[cfg(feature = "postgresql")] + #[cfg(any(feature = "postgresql", feature = "mysql"))] fn visit_json_array_agg(&mut self, array_agg: JsonArrayAgg<'a>) -> Result; - #[cfg(feature = "postgresql")] + #[cfg(any(feature = "postgresql", feature = "mysql"))] fn visit_json_build_object(&mut self, build_obj: JsonBuildObject<'a>) -> Result; #[cfg(any(feature = "postgresql", feature = "mysql"))] @@ -1138,11 +1138,11 @@ pub trait Visitor<'a> { FunctionType::Concat(concat) => { self.visit_concat(concat)?; } - #[cfg(feature = "postgresql")] + #[cfg(any(feature = "postgresql", feature = "mysql"))] FunctionType::JsonArrayAgg(array_agg) => { self.visit_json_array_agg(array_agg)?; } - #[cfg(feature = "postgresql")] + #[cfg(any(feature = "postgresql", feature = "mysql"))] FunctionType::JsonBuildObject(build_obj) => { self.visit_json_build_object(build_obj)?; } diff --git a/quaint/src/visitor/mysql.rs b/quaint/src/visitor/mysql.rs index a406000cd7c0..54029e1b6e98 100644 --- a/quaint/src/visitor/mysql.rs +++ b/quaint/src/visitor/mysql.rs @@ -418,7 +418,6 @@ impl<'a> Visitor<'a> for Mysql<'a> { self.write(", ")?; match json_extract.path.clone() { - #[cfg(feature = "postgresql")] JsonPath::Array(_) => panic!("JSON path array notation is not supported for MySQL"), JsonPath::String(path) => self.visit_parameterized(Value::text(path))?, } diff --git a/quaint/src/visitor/postgres.rs b/quaint/src/visitor/postgres.rs index 8ab679f42701..b80162e7b5c5 100644 --- a/quaint/src/visitor/postgres.rs +++ b/quaint/src/visitor/postgres.rs @@ -409,7 +409,6 @@ impl<'a> Visitor<'a> for Postgres<'a> { #[cfg(any(feature = "postgresql", feature = "mysql"))] fn visit_json_extract(&mut self, json_extract: JsonExtract<'a>) -> visitor::Result { match json_extract.path { - #[cfg(feature = "mysql")] JsonPath::String(_) => panic!("JSON path string notation is not supported for Postgres"), JsonPath::Array(json_path) => { self.write("(")?; diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs index 13206f560776..5a49c6bf8f3e 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs @@ -21,11 +21,14 @@ pub(crate) async fn get_single_record( ctx: &Context<'_>, ) -> crate::Result> { match relation_load_strategy { + #[cfg(any(feature = "postgresql", feature = "mysql"))] RelationLoadStrategy::Join => get_single_record_joins(conn, model, filter, selected_fields, ctx).await, RelationLoadStrategy::Query => get_single_record_wo_joins(conn, model, filter, selected_fields, ctx).await, + _ => unreachable!(), } } +#[cfg(any(feature = "postgresql", feature = "mysql"))] pub(crate) async fn get_single_record_joins( conn: &dyn Queryable, model: &Model, @@ -117,13 +120,16 @@ pub(crate) async fn get_many_records( ctx: &Context<'_>, ) -> crate::Result { match relation_load_strategy { + #[cfg(any(feature = "postgresql", feature = "mysql"))] RelationLoadStrategy::Join => get_many_records_joins(conn, model, query_arguments, selected_fields, ctx).await, RelationLoadStrategy::Query => { get_many_records_wo_joins(conn, model, query_arguments, selected_fields, ctx).await } + _ => unreachable!(), } } +#[cfg(any(feature = "postgresql", feature = "mysql"))] pub(crate) async fn get_many_records_joins( conn: &dyn Queryable, _model: &Model, diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs index d5c067851864..02b4f381514e 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs @@ -10,7 +10,7 @@ use connector_interface::*; use itertools::Itertools; use quaint::{ error::ErrorKind, - prelude::{native_uuid, uuid_to_bin, uuid_to_bin_swapped, Aliasable, Select, SqlFamily}, + prelude::{Aliasable, Select, SqlFamily}, }; use query_structure::*; use std::borrow::Cow; @@ -41,12 +41,14 @@ macro_rules! trace { }; } +#[cfg(feature = "mysql")] async fn generate_id( conn: &dyn Queryable, id_field: &FieldSelection, args: &WriteArgs, ctx: &Context<'_>, ) -> crate::Result> { + use quaint::prelude::{native_uuid, uuid_to_bin, uuid_to_bin_swapped}; // Go through all the values and generate a select statement with the correct MySQL function let (id_select, need_select) = id_field .selections() @@ -82,6 +84,16 @@ async fn generate_id( } } +#[cfg(not(feature = "mysql"))] +async fn generate_id( + _conn: &dyn Queryable, + _id_field: &FieldSelection, + _args: &WriteArgs, + _ctx: &Context<'_>, +) -> crate::Result> { + Ok(None) +} + /// Create a single record to the database defined in `conn`, resulting into a /// `RecordProjection` as an identifier pointing to the just-created record. pub(crate) async fn create_record( @@ -94,7 +106,7 @@ pub(crate) async fn create_record( ) -> crate::Result { let id_field: FieldSelection = model.primary_identifier(); - let returned_id = if *sql_family == SqlFamily::Mysql { + let returned_id = if sql_family.is_mysql() { generate_id(conn, &id_field, &args, ctx) .await? .or_else(|| args.as_selection_result(ModelProjection::from(id_field))) @@ -103,7 +115,7 @@ pub(crate) async fn create_record( }; let args = match returned_id { - Some(ref pk) if *sql_family == SqlFamily::Mysql => { + Some(ref pk) if sql_family.is_mysql() => { for (field, value) in pk.pairs.iter() { let field = DatasourceFieldName(field.db_name().into()); let value = WriteOperation::scalar_set(value.clone()); diff --git a/query-engine/connectors/sql-query-connector/src/filter/visitor.rs b/query-engine/connectors/sql-query-connector/src/filter/visitor.rs index b27ab539e604..1325ce130772 100644 --- a/query-engine/connectors/sql-query-connector/src/filter/visitor.rs +++ b/query-engine/connectors/sql-query-connector/src/filter/visitor.rs @@ -352,6 +352,7 @@ impl FilterVisitorExt for FilterVisitor { fn visit_scalar_filter(&mut self, filter: ScalarFilter, ctx: &Context<'_>) -> ConditionTree<'static> { match filter.condition { + #[cfg(any(feature = "postgresql", feature = "mysql"))] ScalarCondition::Search(_, _) | ScalarCondition::NotSearch(_, _) => { let mut projections = match filter.condition.clone() { ScalarCondition::Search(_, proj) => proj, @@ -610,6 +611,7 @@ impl FilterVisitorExt for FilterVisitor { ScalarListCondition::Contains(ConditionValue::Value(val)) => { comparable.compare_raw("@>", convert_list_pv(field, vec![val], ctx)) } + #[cfg(feature = "postgresql")] ScalarListCondition::Contains(ConditionValue::FieldRef(field_ref)) => { let field_ref_expr: Expression = field_ref.aliased_col(alias, ctx).into(); @@ -630,6 +632,7 @@ impl FilterVisitorExt for FilterVisitor { } ScalarListCondition::IsEmpty(true) => comparable.compare_raw("=", ValueType::Array(Some(vec![])).raw()), ScalarListCondition::IsEmpty(false) => comparable.compare_raw("<>", ValueType::Array(Some(vec![])).raw()), + _ => unreachable!(), }; ConditionTree::single(condition) @@ -705,6 +708,7 @@ fn convert_scalar_filter( ctx: &Context<'_>, ) -> ConditionTree<'static> { match cond { + #[cfg(any(feature = "postgresql", feature = "mysql"))] ScalarCondition::JsonCompare(json_compare) => convert_json_filter( comparable, json_compare, @@ -723,6 +727,7 @@ fn convert_scalar_filter( } } +#[cfg(any(feature = "mysql", feature = "postgresql"))] fn convert_json_filter( comparable: Expression<'static>, json_condition: JsonCondition, @@ -809,6 +814,7 @@ fn convert_json_filter( ConditionTree::single(condition) } +#[cfg(any(feature = "postgresql", feature = "mysql"))] fn with_json_type_filter( comparable: Compare<'static>, expr_json: Expression<'static>, @@ -928,6 +934,7 @@ fn default_scalar_filter( } _ => comparable.in_selection(convert_pvs(fields, values, ctx)), }, + #[cfg(any(feature = "postgresql"))] ScalarCondition::In(ConditionListValue::FieldRef(field_ref)) => { // This code path is only reachable for connectors with `ScalarLists` capability comparable.equals(Expression::from(field_ref.aliased_col(alias, ctx)).any()) @@ -945,10 +952,12 @@ fn default_scalar_filter( } _ => comparable.not_in_selection(convert_pvs(fields, values, ctx)), }, + #[cfg(any(feature = "postgresql"))] ScalarCondition::NotIn(ConditionListValue::FieldRef(field_ref)) => { // This code path is only reachable for connectors with `ScalarLists` capability comparable.not_equals(Expression::from(field_ref.aliased_col(alias, ctx)).all()) } + #[cfg(any(feature = "postgresql"))] ScalarCondition::Search(value, _) => { let query: String = value .into_value() @@ -958,6 +967,7 @@ fn default_scalar_filter( comparable.matches(query) } + #[cfg(any(feature = "postgresql"))] ScalarCondition::NotSearch(value, _) => { let query: String = value .into_value() @@ -969,6 +979,8 @@ fn default_scalar_filter( } ScalarCondition::JsonCompare(_) => unreachable!(), ScalarCondition::IsSet(_) => unreachable!(), + + _ => unreachable!(), }; ConditionTree::single(condition) @@ -1094,6 +1106,7 @@ fn insensitive_scalar_filter( ) } }, + #[cfg(feature = "postgresql")] ScalarCondition::In(ConditionListValue::FieldRef(field_ref)) => { // This code path is only reachable for connectors with `ScalarLists` capability comparable.compare_raw("ILIKE", Expression::from(field_ref.aliased_col(alias, ctx)).any()) @@ -1125,10 +1138,12 @@ fn insensitive_scalar_filter( ) } }, + #[cfg(any(feature = "postgresql"))] ScalarCondition::NotIn(ConditionListValue::FieldRef(field_ref)) => { // This code path is only reachable for connectors with `ScalarLists` capability comparable.compare_raw("NOT ILIKE", Expression::from(field_ref.aliased_col(alias, ctx)).all()) } + #[cfg(any(feature = "postgresql"))] ScalarCondition::Search(value, _) => { let query: String = value .into_value() @@ -1138,6 +1153,7 @@ fn insensitive_scalar_filter( comparable.matches(query) } + #[cfg(any(feature = "postgresql"))] ScalarCondition::NotSearch(value, _) => { let query: String = value .into_value() @@ -1149,6 +1165,7 @@ fn insensitive_scalar_filter( } ScalarCondition::JsonCompare(_) => unreachable!(), ScalarCondition::IsSet(_) => unreachable!(), + _ => unreachable!(), }; ConditionTree::single(condition) @@ -1207,6 +1224,7 @@ fn convert_pvs<'a>(fields: &[ScalarFieldRef], values: Vec, ctx: &Co } } +#[cfg(any(feature = "postgresql", feature = "mysql"))] trait JsonFilterExt { fn json_contains( self, @@ -1239,6 +1257,7 @@ trait JsonFilterExt { ) -> Expression<'static>; } +#[cfg(any(feature = "postgresql", feature = "mysql"))] impl JsonFilterExt for (Expression<'static>, Expression<'static>) { fn json_contains( self, diff --git a/query-engine/connectors/sql-query-connector/src/ordering.rs b/query-engine/connectors/sql-query-connector/src/ordering.rs index ade10aaa7164..0a3716510953 100644 --- a/query-engine/connectors/sql-query-connector/src/ordering.rs +++ b/query-engine/connectors/sql-query-connector/src/ordering.rs @@ -44,7 +44,9 @@ impl OrderByBuilder { self.build_order_aggr_scalar(order_by, needs_reversed_order, ctx) } OrderBy::ToManyAggregation(order_by) => self.build_order_aggr_rel(order_by, needs_reversed_order, ctx), + #[cfg(any(feature = "postgresql", feature = "mysql"))] OrderBy::Relevance(order_by) => self.build_order_relevance(order_by, needs_reversed_order, ctx), + _ => unreachable!(), }) .collect_vec() } @@ -70,6 +72,7 @@ impl OrderByBuilder { } } + #[cfg(any(feature = "postgresql", feature = "mysql"))] fn build_order_relevance( &mut self, order_by: &OrderByRelevance, diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/mod.rs b/query-engine/connectors/sql-query-connector/src/query_builder/mod.rs index 7f16b84f95fd..0d49c1ca7676 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/mod.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod read; +#[cfg(any(feature = "postgresql", feature = "mysql"))] pub(crate) mod select; pub(crate) mod write; diff --git a/query-engine/request-handlers/Cargo.toml b/query-engine/request-handlers/Cargo.toml index 4200980202c5..b27b99b9126d 100644 --- a/query-engine/request-handlers/Cargo.toml +++ b/query-engine/request-handlers/Cargo.toml @@ -24,7 +24,7 @@ connection-string.workspace = true once_cell = "1.15" mongodb-query-connector = { path = "../connectors/mongodb-query-connector", optional = true } -sql-query-connector = { path = "../connectors/sql-query-connector", optional = true } +sql-query-connector = { path = "../connectors/sql-query-connector", optional = true, default-features = false } [dev-dependencies] insta = "1.7.1" @@ -35,6 +35,9 @@ codspeed-criterion-compat = "1.1.0" default = ["sql", "mongodb", "native", "graphql-protocol"] mongodb = ["mongodb-query-connector"] sql = ["sql-query-connector"] +postgresql = ["sql-query-connector", "sql-query-connector/postgresql"] +mysql = ["sql-query-connector", "sql-query-connector/mysql"] +sqlite = ["sql-query-connector", "sql-query-connector/sqlite"] driver-adapters = ["sql-query-connector/driver-adapters"] native = [ "mongodb",