From 3b48a79691098e715f13053dc7d9158237dd57fb Mon Sep 17 00:00:00 2001 From: Alexey Orlenko Date: Thu, 26 Dec 2024 18:58:04 +0100 Subject: [PATCH] first pass at read nested --- libs/prisma-value/src/lib.rs | 4 + query-engine/core/src/compiler/expression.rs | 58 ++++++++- .../core/src/compiler/translate/query/read.rs | 118 ++++++++++++++++-- query-engine/core/src/query_ast/read.rs | 16 +++ .../query-engine/examples/compiler.rs | 10 +- 5 files changed, 194 insertions(+), 12 deletions(-) diff --git a/libs/prisma-value/src/lib.rs b/libs/prisma-value/src/lib.rs index 01a4e5e50572..81adccca2661 100644 --- a/libs/prisma-value/src/lib.rs +++ b/libs/prisma-value/src/lib.rs @@ -381,6 +381,10 @@ impl PrismaValue { PrismaValue::DateTime(parse_datetime(datetime).unwrap()) } + pub fn placeholder(name: String, r#type: PlaceholderType) -> PrismaValue { + PrismaValue::Placeholder { name, r#type } + } + pub fn as_boolean(&self) -> Option<&bool> { match self { PrismaValue::Boolean(bool) => Some(bool), diff --git a/query-engine/core/src/compiler/expression.rs b/query-engine/core/src/compiler/expression.rs index 26e6e066be55..5669e82a695b 100644 --- a/query-engine/core/src/compiler/expression.rs +++ b/query-engine/core/src/compiler/expression.rs @@ -1,3 +1,4 @@ +use itertools::Itertools; use query_structure::PrismaValue; use serde::Serialize; @@ -32,7 +33,13 @@ impl DbQuery { } #[derive(Debug, Serialize)] -#[serde(tag = "type", content = "args")] +pub struct JoinExpression { + pub child: Expression, + pub on: Vec<(String, String)>, +} + +#[derive(Debug, Serialize)] +#[serde(tag = "type", content = "args", rename_all = "camelCase")] pub enum Expression { /// Sequence of statements. The whole sequence evaluates to the result of the last expression. Seq(Vec), @@ -63,6 +70,22 @@ pub enum Expression { /// Concatenates a list of lists. Concat(Vec), + + /// Asserts that the result of the expression is at most one record. + Unique(Box), + + /// Asserts that the result of the expression is at least one record. + Required(Box), + + /// Application-level join. + Join { + parent: Box, + children: Vec, + }, + + /// Get a field from a record or records. If the argument is a list of records, + /// returns a list of values of this field. + MapField { field: String, records: Box }, } impl Expression { @@ -114,6 +137,37 @@ impl Expression { Self::Sum(exprs) => self.display_function("sum", exprs, f, level)?, Self::Concat(exprs) => self.display_function("concat", exprs, f, level)?, + + Self::Unique(expr) => { + writeln!(f, "{indent}unique (")?; + expr.display(f, level + 1)?; + write!(f, "{indent})")?; + } + + Self::Required(expr) => { + writeln!(f, "{indent}required (")?; + expr.display(f, level + 1)?; + write!(f, "{indent})")?; + } + + Self::Join { parent, children } => { + writeln!(f, "{indent}join (")?; + parent.display(f, level + 1)?; + for nested in children { + let left = nested.on.iter().map(|(l, _)| l).cloned().join(", "); + let right = nested.on.iter().map(|(_, r)| r).cloned().join(", "); + writeln!(f, "\n{indent} with (")?; + nested.child.display(f, level + 2)?; + writeln!(f, "\n{indent} ) on left.{left} = right.{right},")?; + } + write!(f, "{indent})")?; + } + + Self::MapField { field, records } => { + writeln!(f, "{indent}mapField {field} (")?; + records.display(f, level + 1)?; + write!(f, "\n{indent})")?; + } } Ok(()) @@ -128,7 +182,7 @@ impl Expression { ) -> std::fmt::Result { let indent = " ".repeat(level); let DbQuery { query, params } = db_query; - write!(f, "{indent}{op} {{\n{indent} {query}\n{indent}}} with {params:?}") + write!(f, "{indent}{op} (\n{indent} {query}\n{indent}) with {params:?}") } fn display_function( diff --git a/query-engine/core/src/compiler/translate/query/read.rs b/query-engine/core/src/compiler/translate/query/read.rs index 076d4379566a..6a27080d7d07 100644 --- a/query-engine/core/src/compiler/translate/query/read.rs +++ b/query-engine/core/src/compiler/translate/query/read.rs @@ -1,16 +1,30 @@ -use query_structure::ModelProjection; +use std::collections::HashSet; + +use itertools::Itertools; +use query_structure::{ + ConditionValue, Filter, ModelProjection, PlaceholderType, PrismaValue, QueryMode, RelationField, ScalarCondition, + ScalarField, ScalarFilter, ScalarProjection, +}; use sql_query_connector::{ context::Context, model_extensions::AsColumns, query_arguments_ext::QueryArgumentsExt, query_builder, }; use crate::{ - compiler::{expression::Expression, translate::TranslateResult}, - ReadQuery, RelatedRecordsQuery, + compiler::{ + expression::{Binding, Expression, JoinExpression}, + translate::TranslateResult, + }, + FilteredQuery, ReadQuery, RelatedRecordsQuery, }; use super::build_db_query; pub(crate) fn translate_read_query(query: ReadQuery, ctx: &Context<'_>) -> TranslateResult { + let all_linking_fields = query + .nested_related_records_queries() + .flat_map(|rrq| rrq.parent_field.linking_fields()) + .collect::>(); + Ok(match query { ReadQuery::RecordQuery(rq) => { let selected_fields = rq.selected_fields.without_relations().into_virtuals_last(); @@ -26,7 +40,66 @@ pub(crate) fn translate_read_query(query: ReadQuery, ctx: &Context<'_>) -> Trans ) .limit(1); - Expression::Query(build_db_query(query)?) + let expr = Expression::Query(build_db_query(query)?); + + if rq.nested.is_empty() { + return Ok(expr); + } + + Expression::Let { + bindings: vec![Binding { + name: "@parent".into(), + expr, + }], + expr: Box::new(Expression::Let { + bindings: all_linking_fields + .into_iter() + .map(|sf| Binding { + name: format!("@parent.{}", sf.prisma_name().into_owned()), + expr: Expression::MapField { + field: sf.prisma_name().into_owned(), + records: Box::new(Expression::Get { name: "@parent".into() }), + }, + }) + .collect(), + expr: Box::new(Expression::Join { + parent: Box::new(Expression::Get { name: "@parent".into() }), + children: rq + .nested + .into_iter() + .filter_map(|nested| match nested { + ReadQuery::RelatedRecordsQuery(rrq) => Some(rrq), + _ => None, + }) + .map(|rrq| -> TranslateResult { + let parent_fields = rrq.parent_field.linking_fields(); + let child_fields = rrq.parent_field.related_field().linking_fields(); + + let join_expr = parent_fields + .scalars() + .zip(child_fields.scalars()) + .map(|(left, right)| (left.name().to_owned(), right.name().to_owned())) + .collect_vec(); + + // nested.add_filter(Filter::Scalar(ScalarFilter { + // mode: QueryMode::Default, + // condition: ScalarCondition::Equals(ConditionValue::value(PrismaValue::placeholder( + // "parent_id".into(), + // PlaceholderType::String, + // ))), + // projection: ScalarProjection::Compound(referenced_fields), + // })); + let child_query = translate_read_query(ReadQuery::RelatedRecordsQuery(rrq), ctx)?; + + Ok(JoinExpression { + child: child_query, + on: join_expr, + }) + }) + .try_collect()?, + }), + }), + } } ReadQuery::ManyRecordsQuery(mrq) => { @@ -61,7 +134,7 @@ pub(crate) fn translate_read_query(query: ReadQuery, ctx: &Context<'_>) -> Trans } } - _ => unimplemented!(), + _ => todo!(), }) } @@ -69,6 +142,37 @@ fn build_read_m2m_query(_query: RelatedRecordsQuery, _ctx: &Context<'_>) -> Tran todo!() } -fn build_read_one2m_query(_query: RelatedRecordsQuery, _ctx: &Context<'_>) -> TranslateResult { - todo!() +fn build_read_one2m_query(rrq: RelatedRecordsQuery, ctx: &Context<'_>) -> TranslateResult { + let selected_fields = rrq.selected_fields.without_relations().into_virtuals_last(); + let needs_reversed_order = rrq.args.needs_reversed_order(); + + // TODO: we ignore chunking for now + let query = query_builder::read::get_records( + &rrq.parent_field.related_model(), + ModelProjection::from(&selected_fields) + .as_columns(ctx) + .mark_all_selected(), + selected_fields.virtuals(), + rrq.args, + ctx, + ); + + let expr = Expression::Query(build_db_query(query)?); + + if needs_reversed_order { + Ok(Expression::Reverse(Box::new(expr))) + } else { + Ok(expr) + } +} + +fn collect_referenced_fields(nested_queries: &[ReadQuery]) -> HashSet { + nested_queries + .iter() + .filter_map(|rq| match rq { + ReadQuery::RelatedRecordsQuery(rrq) => Some(rrq), + _ => None, + }) + .flat_map(|rrq| rrq.parent_field.referenced_fields()) + .collect() } diff --git a/query-engine/core/src/query_ast/read.rs b/query-engine/core/src/query_ast/read.rs index e3eca8c88ee5..6d25f8f435ff 100644 --- a/query-engine/core/src/query_ast/read.rs +++ b/query-engine/core/src/query_ast/read.rs @@ -64,6 +64,22 @@ impl ReadQuery { ReadQuery::AggregateRecordsQuery(_) => false, } } + + fn nested(&self) -> &[ReadQuery] { + match self { + ReadQuery::RecordQuery(x) => &x.nested, + ReadQuery::ManyRecordsQuery(x) => &x.nested, + ReadQuery::RelatedRecordsQuery(x) => &x.nested, + ReadQuery::AggregateRecordsQuery(_) => &[], + } + } + + pub fn nested_related_records_queries(&self) -> impl Iterator + '_ { + self.nested().iter().filter_map(|q| match q { + ReadQuery::RelatedRecordsQuery(rrq) => Some(rrq), + _ => None, + }) + } } impl FilteredQuery for ReadQuery { diff --git a/query-engine/query-engine/examples/compiler.rs b/query-engine/query-engine/examples/compiler.rs index 7a1150cc3651..950fd99f255e 100644 --- a/query-engine/query-engine/examples/compiler.rs +++ b/query-engine/query-engine/examples/compiler.rs @@ -15,14 +15,18 @@ pub fn main() -> anyhow::Result<()> { let schema = Arc::new(schema); let query_schema = Arc::new(query_core::schema::build(schema, true)); - // prisma.user.findMany({ + // prisma.user.findUnique({ // where: { // email: Prisma.Param("userEmail") + // }, + // select: { + // val: true, + // posts: true, // } // }) let query: JsonSingleQuery = serde_json::from_value(json!({ "modelName": "User", - "action": "findMany", + "action": "findUnique", "query": { "arguments": { "where": { @@ -33,7 +37,7 @@ pub fn main() -> anyhow::Result<()> { } }, "selection": { - "$scalars": true, + "val": true, "posts": { "arguments": {}, "selection": {