diff --git a/query-engine/core/src/compiler/expression.rs b/query-engine/core/src/compiler/expression.rs index 65b5d22bcd0..1a3a3606bb6 100644 --- a/query-engine/core/src/compiler/expression.rs +++ b/query-engine/core/src/compiler/expression.rs @@ -91,6 +91,20 @@ pub enum Expression { MapField { field: String, records: Box }, } +#[derive(Debug, Clone)] +pub enum ExpressionType { + Scalar, + Record, + List(Box), + Dynamic, +} + +impl ExpressionType { + pub fn is_list(&self) -> bool { + matches!(self, ExpressionType::List(_) | ExpressionType::Dynamic) + } +} + #[derive(thiserror::Error, Debug)] pub enum PrettyPrintError { #[error("{0}")] @@ -114,6 +128,29 @@ impl Expression { doc.render_colored(width, &mut buf)?; Ok(String::from_utf8(buf.into_inner())?) } + + pub fn r#type(&self) -> ExpressionType { + match self { + Expression::Seq(vec) => vec.iter().last().map_or(ExpressionType::Scalar, Expression::r#type), + Expression::Get { .. } => ExpressionType::Dynamic, + Expression::Let { expr, .. } => expr.r#type(), + Expression::GetFirstNonEmpty { .. } => ExpressionType::Dynamic, + Expression::Query(_) => ExpressionType::List(Box::new(ExpressionType::Record)), + Expression::Execute(_) => ExpressionType::Scalar, + Expression::Reverse(expression) => expression.r#type(), + Expression::Sum(_) => ExpressionType::Scalar, + Expression::Concat(vec) => ExpressionType::List(Box::new( + vec.iter().last().map_or(ExpressionType::Scalar, Expression::r#type), + )), + Expression::Unique(expression) => match expression.r#type() { + ExpressionType::List(inner) => inner.as_ref().clone(), + _ => expression.r#type(), + }, + Expression::Required(expression) => expression.r#type(), + Expression::Join { parent, .. } => parent.r#type(), + Expression::MapField { records, .. } => records.r#type(), + } + } } impl std::fmt::Display for Expression { diff --git a/query-engine/core/src/compiler/translate/query/read.rs b/query-engine/core/src/compiler/translate/query/read.rs index 3fddda418e9..a2f51b5aac8 100644 --- a/query-engine/core/src/compiler/translate/query/read.rs +++ b/query-engine/core/src/compiler/translate/query/read.rs @@ -2,8 +2,8 @@ use std::collections::HashSet; use itertools::Itertools; use query_structure::{ - ConditionValue, Filter, ModelProjection, PlaceholderType, PrismaValue, QueryMode, RelationField, ScalarCondition, - ScalarField, ScalarFilter, ScalarProjection, + ConditionListValue, ConditionValue, Filter, ModelProjection, PlaceholderType, PrismaValue, QueryMode, + RelationField, ScalarCondition, ScalarField, ScalarFilter, ScalarProjection, SelectedField, SelectionResult, }; use sql_query_connector::{ context::Context, model_extensions::AsColumns, query_arguments_ext::QueryArgumentsExt, query_builder, @@ -113,7 +113,7 @@ fn add_inmemory_join(parent: Expression, nested: Vec, ctx: &Context<' ReadQuery::RelatedRecordsQuery(rrq) => Some(rrq), _ => None, }) - .map(|rrq| -> TranslateResult { + .map(|mut rrq| -> TranslateResult { let parent_field_name = rrq.parent_field.name().to_owned(); let parent_fields = rrq.parent_field.linking_fields(); let child_fields = rrq.parent_field.related_field().linking_fields(); @@ -124,14 +124,53 @@ fn add_inmemory_join(parent: Expression, nested: Vec, ctx: &Context<' .map(|(left, right)| (left.name().to_owned(), right.name().to_owned())) .collect_vec(); - // nested.add_filter(Filter::Scalar(ScalarFilter { - // mode: QueryMode::Default, - // condition: ScalarCondition::Equals(ConditionValue::value(PrismaValue::placeholder( - // "parent_id".into(), - // PlaceholderType::String, - // ))), - // projection: ScalarProjection::Compound(referenced_fields), - // })); + // let linking_placeholders = parent_fields + // .scalars() + // .map(|sf| { + // ( + // sf.clone(), + // PrismaValue::placeholder( + // format!("@parent${}", sf.name()), + // sf.type_identifier().to_placeholder_type(), + // ), + // ) + // }) + // .collect::>(); + // + // // If constant values were already provided for some of the fields, merge the + // // placeholders for the missing fields. Otherwise, assign new `parent_results`. + // if let Some(parent_results) = &mut rrq.parent_results { + // for result in parent_results { + // for (sf, value) in &linking_placeholders { + // let field = SelectedField::from(sf.clone()); + // if result.get(&field).is_none() { + // result.add((field, value.clone())); + // } + // } + // } + // } else { + // rrq.parent_results = Some(vec![SelectionResult::new(linking_placeholders)]); + // } + + for (parent_field, child_field) in parent_fields.scalars().zip(child_fields.scalars()) { + let placeholder = PrismaValue::placeholder( + format!("@parent${}", parent_field.name()), + parent_field.type_identifier().to_placeholder_type(), + ); + + let condition = if parent.r#type().is_list() { + ScalarCondition::In(ConditionListValue::list(vec![placeholder])) + } else { + ScalarCondition::Equals(ConditionValue::value(placeholder)) + }; + + rrq.add_filter(Filter::Scalar(ScalarFilter { + condition, + projection: ScalarProjection::Single(child_field.clone()), + mode: QueryMode::Default, + })); + } + let child_query = translate_read_query(ReadQuery::RelatedRecordsQuery(rrq), ctx)?; Ok(JoinExpression { diff --git a/query-engine/core/src/query_ast/read.rs b/query-engine/core/src/query_ast/read.rs index e3eca8c88ee..2326183b7b5 100644 --- a/query-engine/core/src/query_ast/read.rs +++ b/query-engine/core/src/query_ast/read.rs @@ -250,3 +250,13 @@ impl FilteredQuery for ManyRecordsQuery { self.args.filter = Some(filter) } } + +impl FilteredQuery for RelatedRecordsQuery { + fn get_filter(&mut self) -> Option<&mut Filter> { + self.args.filter.as_mut() + } + + fn set_filter(&mut self, filter: Filter) { + self.args.filter = Some(filter) + } +} diff --git a/query-engine/query-structure/src/field/mod.rs b/query-engine/query-structure/src/field/mod.rs index d8faf404e66..93b9ccd98d8 100644 --- a/query-engine/query-structure/src/field/mod.rs +++ b/query-engine/query-structure/src/field/mod.rs @@ -3,6 +3,7 @@ mod relation; mod scalar; pub use composite::*; +use prisma_value::PlaceholderType; pub use relation::*; pub use scalar::*; @@ -179,6 +180,23 @@ impl TypeIdentifier { } } + pub fn to_placeholder_type(&self) -> PlaceholderType { + match self { + TypeIdentifier::String => PlaceholderType::String, + TypeIdentifier::Int => PlaceholderType::Int, + TypeIdentifier::BigInt => PlaceholderType::BigInt, + TypeIdentifier::Float => PlaceholderType::Float, + TypeIdentifier::Decimal => PlaceholderType::Decimal, + TypeIdentifier::Boolean => PlaceholderType::Boolean, + TypeIdentifier::Enum(_) => PlaceholderType::String, + TypeIdentifier::UUID => PlaceholderType::String, + TypeIdentifier::Json => PlaceholderType::Object, + TypeIdentifier::DateTime => PlaceholderType::Date, + TypeIdentifier::Bytes => PlaceholderType::Bytes, + TypeIdentifier::Unsupported => PlaceholderType::Any, + } + } + /// Returns `true` if the type identifier is [`Enum`]. pub fn is_enum(&self) -> bool { matches!(self, Self::Enum(..))