Skip to content

Commit

Permalink
qe: Remove more per-connector code after splitting
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergey Tatarintsev committed Feb 14, 2024
1 parent ed5ca1a commit c54730b
Show file tree
Hide file tree
Showing 14 changed files with 72 additions and 22 deletions.
2 changes: 1 addition & 1 deletion quaint/src/ast/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl<'a> Delete<'a> {
/// assert_eq!("DELETE FROM `users` RETURNING \"id\"", sql);
/// # Ok(())
/// # }
#[cfg(any(feature = "postgresql", feature = "mssql", feature = "sqlite"))]
#[cfg(any(feature = "postgresql", feature = "mssql", feature = "sqlite", feature = "mysql"))]
pub fn returning<K, I>(mut self, columns: I) -> Self
where
K: Into<Column<'a>>,
Expand Down
18 changes: 13 additions & 5 deletions quaint/src/ast/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ mod average;
mod coalesce;
mod concat;
mod count;
#[cfg(any(feature = "postgresql", feature = "mysql"))]
mod json_array_agg;
#[cfg(any(feature = "postgresql", feature = "mysql"))]
mod json_build_obj;
#[cfg(any(feature = "postgresql", feature = "mysql"))]
mod json_extract;
Expand All @@ -30,7 +32,9 @@ pub use average::*;
pub use coalesce::*;
pub use concat::*;
pub use count::*;
#[cfg(any(feature = "postgresql", feature = "mysql"))]
pub use json_array_agg::*;
#[cfg(any(feature = "postgresql", feature = "mysql"))]
pub use json_build_obj::*;
#[cfg(any(feature = "postgresql", feature = "mysql"))]
pub use json_extract::*;
Expand Down Expand Up @@ -102,9 +106,9 @@ pub(crate) enum FunctionType<'a> {
JsonExtractFirstArrayElem(JsonExtractFirstArrayElem<'a>),
#[cfg(any(feature = "postgresql", feature = "mysql"))]
JsonUnquote(JsonUnquote<'a>),
#[cfg(feature = "postgresql")]
#[cfg(any(feature = "postgresql", feature = "mysql"))]
JsonArrayAgg(JsonArrayAgg<'a>),
#[cfg(feature = "postgresql")]
#[cfg(any(feature = "postgresql", feature = "mysql"))]
JsonBuildObject(JsonBuildObject<'a>),
#[cfg(any(feature = "postgresql", feature = "mysql"))]
TextSearch(TextSearch<'a>),
Expand Down Expand Up @@ -151,6 +155,12 @@ function!(TextSearch);
#[cfg(any(feature = "postgresql", feature = "mysql"))]
function!(TextSearchRelevance);

#[cfg(any(feature = "postgresql", feature = "mysql"))]
function!(JsonArrayAgg);

#[cfg(any(feature = "postgresql", feature = "mysql"))]
function!(JsonBuildObject);

function!(
RowNumber,
Count,
Expand All @@ -162,7 +172,5 @@ function!(
Minimum,
Maximum,
Coalesce,
Concat,
JsonArrayAgg,
JsonBuildObject
Concat
);
8 changes: 4 additions & 4 deletions quaint/src/ast/function/json_extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@ pub struct JsonExtract<'a> {

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JsonPath<'a> {
#[cfg(feature = "mysql")]
// #[cfg(feature = "mysql")]
String(Cow<'a, str>),
#[cfg(feature = "postgresql")]
// #[cfg(feature = "postgresql")]
Array(Vec<Cow<'a, str>>),
}

impl<'a> JsonPath<'a> {
#[cfg(feature = "mysql")]
// #[cfg(feature = "mysql")]
pub fn string<S>(string: S) -> JsonPath<'a>
where
S: Into<Cow<'a, str>>,
{
JsonPath::String(string.into())
}

#[cfg(feature = "postgresql")]
// #[cfg(feature = "postgresql")]
pub fn array<A, V>(array: A) -> JsonPath<'a>
where
V: Into<Cow<'a, str>>,
Expand Down
2 changes: 1 addition & 1 deletion quaint/src/ast/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ impl<'a> Insert<'a> {
/// # Ok(())
/// # }
/// ```
#[cfg(any(feature = "postgresql", feature = "mssql", feature = "sqlite"))]
#[cfg(any(feature = "postgresql", feature = "mssql", feature = "sqlite", feature = "mysql"))]
pub fn returning<K, I>(mut self, columns: I) -> Self
where
K: Into<Column<'a>>,
Expand Down
2 changes: 1 addition & 1 deletion quaint/src/ast/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl<'a> Update<'a> {
/// # Ok(())
/// # }
/// ```
#[cfg(any(feature = "postgresql", feature = "sqlite"))]
#[cfg(any(feature = "postgresql", feature = "sqlite", feature = "mysql"))]
pub fn returning<K, I>(mut self, columns: I) -> Self
where
K: Into<Column<'a>>,
Expand Down
8 changes: 4 additions & 4 deletions quaint/src/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,10 @@ pub trait Visitor<'a> {
#[cfg(any(feature = "postgresql", feature = "mysql"))]
fn visit_json_unquote(&mut self, json_unquote: JsonUnquote<'a>) -> Result;

#[cfg(feature = "postgresql")]
#[cfg(any(feature = "postgresql", feature = "mysql"))]
fn visit_json_array_agg(&mut self, array_agg: JsonArrayAgg<'a>) -> Result;

#[cfg(feature = "postgresql")]
#[cfg(any(feature = "postgresql", feature = "mysql"))]
fn visit_json_build_object(&mut self, build_obj: JsonBuildObject<'a>) -> Result;

#[cfg(any(feature = "postgresql", feature = "mysql"))]
Expand Down Expand Up @@ -1138,11 +1138,11 @@ pub trait Visitor<'a> {
FunctionType::Concat(concat) => {
self.visit_concat(concat)?;
}
#[cfg(feature = "postgresql")]
#[cfg(any(feature = "postgresql", feature = "mysql"))]
FunctionType::JsonArrayAgg(array_agg) => {
self.visit_json_array_agg(array_agg)?;
}
#[cfg(feature = "postgresql")]
#[cfg(any(feature = "postgresql", feature = "mysql"))]
FunctionType::JsonBuildObject(build_obj) => {
self.visit_json_build_object(build_obj)?;
}
Expand Down
1 change: 0 additions & 1 deletion quaint/src/visitor/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,6 @@ impl<'a> Visitor<'a> for Mysql<'a> {
self.write(", ")?;

match json_extract.path.clone() {
#[cfg(feature = "postgresql")]
JsonPath::Array(_) => panic!("JSON path array notation is not supported for MySQL"),
JsonPath::String(path) => self.visit_parameterized(Value::text(path))?,
}
Expand Down
1 change: 0 additions & 1 deletion quaint/src/visitor/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,6 @@ impl<'a> Visitor<'a> for Postgres<'a> {
#[cfg(any(feature = "postgresql", feature = "mysql"))]
fn visit_json_extract(&mut self, json_extract: JsonExtract<'a>) -> visitor::Result {
match json_extract.path {
#[cfg(feature = "mysql")]
JsonPath::String(_) => panic!("JSON path string notation is not supported for Postgres"),
JsonPath::Array(json_path) => {
self.write("(")?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ pub(crate) async fn get_single_record(
ctx: &Context<'_>,
) -> crate::Result<Option<SingleRecord>> {
match relation_load_strategy {
#[cfg(any(feature = "postgresql", feature = "mysql"))]
RelationLoadStrategy::Join => get_single_record_joins(conn, model, filter, selected_fields, ctx).await,
RelationLoadStrategy::Query => get_single_record_wo_joins(conn, model, filter, selected_fields, ctx).await,
_ => unreachable!(),
}
}

#[cfg(any(feature = "postgresql", feature = "mysql"))]
pub(crate) async fn get_single_record_joins(
conn: &dyn Queryable,
model: &Model,
Expand Down Expand Up @@ -117,13 +120,16 @@ pub(crate) async fn get_many_records(
ctx: &Context<'_>,
) -> crate::Result<ManyRecords> {
match relation_load_strategy {
#[cfg(any(feature = "postgresql", feature = "mysql"))]
RelationLoadStrategy::Join => get_many_records_joins(conn, model, query_arguments, selected_fields, ctx).await,
RelationLoadStrategy::Query => {
get_many_records_wo_joins(conn, model, query_arguments, selected_fields, ctx).await
}
_ => unreachable!(),
}
}

#[cfg(any(feature = "postgresql", feature = "mysql"))]
pub(crate) async fn get_many_records_joins(
conn: &dyn Queryable,
_model: &Model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use connector_interface::*;
use itertools::Itertools;
use quaint::{
error::ErrorKind,
prelude::{native_uuid, uuid_to_bin, uuid_to_bin_swapped, Aliasable, Select, SqlFamily},
prelude::{Aliasable, Select, SqlFamily},
};
use query_structure::*;
use std::borrow::Cow;
Expand Down Expand Up @@ -41,12 +41,14 @@ macro_rules! trace {
};
}

#[cfg(feature = "mysql")]
async fn generate_id(
conn: &dyn Queryable,
id_field: &FieldSelection,
args: &WriteArgs,
ctx: &Context<'_>,
) -> crate::Result<Option<SelectionResult>> {
use quaint::prelude::{native_uuid, uuid_to_bin, uuid_to_bin_swapped};
// Go through all the values and generate a select statement with the correct MySQL function
let (id_select, need_select) = id_field
.selections()
Expand Down Expand Up @@ -82,6 +84,16 @@ async fn generate_id(
}
}

#[cfg(not(feature = "mysql"))]
async fn generate_id(
_conn: &dyn Queryable,
_id_field: &FieldSelection,
_args: &WriteArgs,
_ctx: &Context<'_>,
) -> crate::Result<Option<SelectionResult>> {
Ok(None)
}

/// Create a single record to the database defined in `conn`, resulting into a
/// `RecordProjection` as an identifier pointing to the just-created record.
pub(crate) async fn create_record(
Expand All @@ -94,7 +106,7 @@ pub(crate) async fn create_record(
) -> crate::Result<SingleRecord> {
let id_field: FieldSelection = model.primary_identifier();

let returned_id = if *sql_family == SqlFamily::Mysql {
let returned_id = if sql_family.is_mysql() {
generate_id(conn, &id_field, &args, ctx)
.await?
.or_else(|| args.as_selection_result(ModelProjection::from(id_field)))
Expand All @@ -103,7 +115,7 @@ pub(crate) async fn create_record(
};

let args = match returned_id {
Some(ref pk) if *sql_family == SqlFamily::Mysql => {
Some(ref pk) if sql_family.is_mysql() => {
for (field, value) in pk.pairs.iter() {
let field = DatasourceFieldName(field.db_name().into());
let value = WriteOperation::scalar_set(value.clone());
Expand Down
19 changes: 19 additions & 0 deletions query-engine/connectors/sql-query-connector/src/filter/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ impl FilterVisitorExt for FilterVisitor {

fn visit_scalar_filter(&mut self, filter: ScalarFilter, ctx: &Context<'_>) -> ConditionTree<'static> {
match filter.condition {
#[cfg(any(feature = "postgresql", feature = "mysql"))]
ScalarCondition::Search(_, _) | ScalarCondition::NotSearch(_, _) => {
let mut projections = match filter.condition.clone() {
ScalarCondition::Search(_, proj) => proj,
Expand Down Expand Up @@ -610,6 +611,7 @@ impl FilterVisitorExt for FilterVisitor {
ScalarListCondition::Contains(ConditionValue::Value(val)) => {
comparable.compare_raw("@>", convert_list_pv(field, vec![val], ctx))
}
#[cfg(feature = "postgresql")]
ScalarListCondition::Contains(ConditionValue::FieldRef(field_ref)) => {
let field_ref_expr: Expression = field_ref.aliased_col(alias, ctx).into();

Expand All @@ -630,6 +632,7 @@ impl FilterVisitorExt for FilterVisitor {
}
ScalarListCondition::IsEmpty(true) => comparable.compare_raw("=", ValueType::Array(Some(vec![])).raw()),
ScalarListCondition::IsEmpty(false) => comparable.compare_raw("<>", ValueType::Array(Some(vec![])).raw()),
_ => unreachable!(),
};

ConditionTree::single(condition)
Expand Down Expand Up @@ -705,6 +708,7 @@ fn convert_scalar_filter(
ctx: &Context<'_>,
) -> ConditionTree<'static> {
match cond {
#[cfg(any(feature = "postgresql", feature = "mysql"))]
ScalarCondition::JsonCompare(json_compare) => convert_json_filter(
comparable,
json_compare,
Expand All @@ -723,6 +727,7 @@ fn convert_scalar_filter(
}
}

#[cfg(any(feature = "mysql", feature = "postgresql"))]
fn convert_json_filter(
comparable: Expression<'static>,
json_condition: JsonCondition,
Expand Down Expand Up @@ -809,6 +814,7 @@ fn convert_json_filter(
ConditionTree::single(condition)
}

#[cfg(any(feature = "postgresql", feature = "mysql"))]
fn with_json_type_filter(
comparable: Compare<'static>,
expr_json: Expression<'static>,
Expand Down Expand Up @@ -928,6 +934,7 @@ fn default_scalar_filter(
}
_ => comparable.in_selection(convert_pvs(fields, values, ctx)),
},
#[cfg(any(feature = "postgresql"))]
ScalarCondition::In(ConditionListValue::FieldRef(field_ref)) => {
// This code path is only reachable for connectors with `ScalarLists` capability
comparable.equals(Expression::from(field_ref.aliased_col(alias, ctx)).any())
Expand All @@ -945,10 +952,12 @@ fn default_scalar_filter(
}
_ => comparable.not_in_selection(convert_pvs(fields, values, ctx)),
},
#[cfg(any(feature = "postgresql"))]
ScalarCondition::NotIn(ConditionListValue::FieldRef(field_ref)) => {
// This code path is only reachable for connectors with `ScalarLists` capability
comparable.not_equals(Expression::from(field_ref.aliased_col(alias, ctx)).all())
}
#[cfg(any(feature = "postgresql"))]
ScalarCondition::Search(value, _) => {
let query: String = value
.into_value()
Expand All @@ -958,6 +967,7 @@ fn default_scalar_filter(

comparable.matches(query)
}
#[cfg(any(feature = "postgresql"))]
ScalarCondition::NotSearch(value, _) => {
let query: String = value
.into_value()
Expand All @@ -969,6 +979,8 @@ fn default_scalar_filter(
}
ScalarCondition::JsonCompare(_) => unreachable!(),
ScalarCondition::IsSet(_) => unreachable!(),

_ => unreachable!(),
};

ConditionTree::single(condition)
Expand Down Expand Up @@ -1094,6 +1106,7 @@ fn insensitive_scalar_filter(
)
}
},
#[cfg(feature = "postgresql")]
ScalarCondition::In(ConditionListValue::FieldRef(field_ref)) => {
// This code path is only reachable for connectors with `ScalarLists` capability
comparable.compare_raw("ILIKE", Expression::from(field_ref.aliased_col(alias, ctx)).any())
Expand Down Expand Up @@ -1125,10 +1138,12 @@ fn insensitive_scalar_filter(
)
}
},
#[cfg(any(feature = "postgresql"))]
ScalarCondition::NotIn(ConditionListValue::FieldRef(field_ref)) => {
// This code path is only reachable for connectors with `ScalarLists` capability
comparable.compare_raw("NOT ILIKE", Expression::from(field_ref.aliased_col(alias, ctx)).all())
}
#[cfg(any(feature = "postgresql"))]
ScalarCondition::Search(value, _) => {
let query: String = value
.into_value()
Expand All @@ -1138,6 +1153,7 @@ fn insensitive_scalar_filter(

comparable.matches(query)
}
#[cfg(any(feature = "postgresql"))]
ScalarCondition::NotSearch(value, _) => {
let query: String = value
.into_value()
Expand All @@ -1149,6 +1165,7 @@ fn insensitive_scalar_filter(
}
ScalarCondition::JsonCompare(_) => unreachable!(),
ScalarCondition::IsSet(_) => unreachable!(),
_ => unreachable!(),
};

ConditionTree::single(condition)
Expand Down Expand Up @@ -1207,6 +1224,7 @@ fn convert_pvs<'a>(fields: &[ScalarFieldRef], values: Vec<PrismaValue>, ctx: &Co
}
}

#[cfg(any(feature = "postgresql", feature = "mysql"))]
trait JsonFilterExt {
fn json_contains(
self,
Expand Down Expand Up @@ -1239,6 +1257,7 @@ trait JsonFilterExt {
) -> Expression<'static>;
}

#[cfg(any(feature = "postgresql", feature = "mysql"))]
impl JsonFilterExt for (Expression<'static>, Expression<'static>) {
fn json_contains(
self,
Expand Down
3 changes: 3 additions & 0 deletions query-engine/connectors/sql-query-connector/src/ordering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ impl OrderByBuilder {
self.build_order_aggr_scalar(order_by, needs_reversed_order, ctx)
}
OrderBy::ToManyAggregation(order_by) => self.build_order_aggr_rel(order_by, needs_reversed_order, ctx),
#[cfg(any(feature = "postgresql", feature = "mysql"))]
OrderBy::Relevance(order_by) => self.build_order_relevance(order_by, needs_reversed_order, ctx),
_ => unreachable!(),
})
.collect_vec()
}
Expand All @@ -70,6 +72,7 @@ impl OrderByBuilder {
}
}

#[cfg(any(feature = "postgresql", feature = "mysql"))]
fn build_order_relevance(
&mut self,
order_by: &OrderByRelevance,
Expand Down
Loading

0 comments on commit c54730b

Please sign in to comment.