diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/mod.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/mod.rs index a34936908cf1..e1342917d890 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/mod.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/mod.rs @@ -11,4 +11,5 @@ mod insert_null_in_required_field; mod non_embedded_upsert; mod update; mod update_many; +mod update_many_and_return; mod upsert; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many_and_return.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many_and_return.rs new file mode 100644 index 000000000000..1dad962e83de --- /dev/null +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many_and_return.rs @@ -0,0 +1,4 @@ +use query_engine_tests::*; + +#[test_suite(capabilities(UpdateMany, UpdateReturning))] +mod update_many_and_return {} diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs b/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs index aeedfc236099..b5aaea79d9b2 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs @@ -114,6 +114,17 @@ impl WriteOperations for MongoDbConnection { .await } + async fn update_records_returning( + &mut self, + _model: &Model, + _record_filter: connector_interface::RecordFilter, + _args: WriteArgs, + _selected_fields: FieldSelection, + _traceparent: Option, + ) -> connector_interface::Result { + unimplemented!() + } + async fn update_record( &mut self, model: &Model, diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs index ae6546edb7f1..0ba1350d92f6 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs @@ -144,6 +144,17 @@ impl WriteOperations for MongoDbTransaction<'_> { .await } + async fn update_records_returning( + &mut self, + _model: &Model, + _record_filter: connector_interface::RecordFilter, + _args: connector_interface::WriteArgs, + _selected_fields: FieldSelection, + _traceparent: Option, + ) -> connector_interface::Result { + unimplemented!() + } + async fn update_record( &mut self, model: &Model, diff --git a/query-engine/connectors/query-connector/src/interface.rs b/query-engine/connectors/query-connector/src/interface.rs index 6c7c003903df..c196c76bc309 100644 --- a/query-engine/connectors/query-connector/src/interface.rs +++ b/query-engine/connectors/query-connector/src/interface.rs @@ -289,6 +289,19 @@ pub trait WriteOperations { traceparent: Option, ) -> crate::Result; + /// Updates many records at once into the database and returns their + /// selected fields. + /// This method should not be used if the connector does not support + /// returning updated rows. + async fn update_records_returning( + &mut self, + model: &Model, + record_filter: RecordFilter, + args: WriteArgs, + selected_fields: FieldSelection, + traceparent: Option, + ) -> crate::Result; + /// Update record in the `Model` with the given `WriteArgs` filtered by the /// `Filter`. async fn update_record( diff --git a/query-engine/connectors/sql-query-connector/src/database/connection.rs b/query-engine/connectors/sql-query-connector/src/database/connection.rs index 924af5ec12e4..af0a15192c1a 100644 --- a/query-engine/connectors/sql-query-connector/src/database/connection.rs +++ b/query-engine/connectors/sql-query-connector/src/database/connection.rs @@ -236,6 +236,22 @@ where .await } + async fn update_records_returning( + &mut self, + model: &Model, + record_filter: RecordFilter, + args: WriteArgs, + selected_fields: FieldSelection, + traceparent: Option, + ) -> connector::Result { + let ctx = Context::new(&self.connection_info, traceparent); + catch( + &self.connection_info, + write::update_records_returning(&self.inner, model, record_filter, args, selected_fields, &ctx), + ) + .await + } + async fn update_record( &mut self, model: &Model, diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/update.rs b/query-engine/connectors/sql-query-connector/src/database/operations/update.rs index 0dc8081f97d2..d0845ed83451 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/update.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/update.rs @@ -2,12 +2,14 @@ use super::read::get_single_record; use crate::column_metadata::{self, ColumnMetadata}; use crate::filter::FilterBuilder; +use crate::model_extensions::AsColumns; use crate::query_builder::write::{build_update_and_set_query, chunk_update_with_ids}; use crate::row::ToSqlRow; use crate::{Context, QueryExt, Queryable}; use connector_interface::*; use itertools::Itertools; +use quaint::ast::Query; use query_structure::*; /// Performs an update with an explicit selection set. @@ -77,7 +79,7 @@ pub(crate) async fn update_one_without_selection( let id_args = pick_args(&model.primary_identifier().into(), &args); // Perform the update and return the ids on which we've applied the update. // Note: We are _not_ getting back the ids from the update. Either we got some ids passed from the parent operation or we perform a read _before_ doing the update. - let (_, ids) = update_many_from_ids_and_filter(conn, model, record_filter, args, ctx).await?; + let (_, ids) = update_many_from_ids_and_filter(conn, model, record_filter, args, None, ctx).await?; // Since we could not get the ids back from the update, we need to apply in-memory transformation to the ids in case they were part of the update. // This is critical to ensure the following operations can operate on the updated ids. let merged_ids = merge_write_args(ids, id_args); @@ -92,53 +94,50 @@ pub(crate) async fn update_one_without_selection( // Generates a query like this: // UPDATE "public"."User" SET "name" = $1 WHERE "public"."User"."age" > $1 -pub(crate) async fn update_many_from_filter( - conn: &dyn Queryable, +pub(super) async fn update_many_from_filter( model: &Model, record_filter: RecordFilter, args: WriteArgs, + selected_fields: Option<&ModelProjection>, ctx: &Context<'_>, -) -> crate::Result { +) -> crate::Result> { let update = build_update_and_set_query(model, args, None, ctx); let filter_condition = FilterBuilder::without_top_level_joins().visit_filter(record_filter.filter, ctx); let update = update.so_that(filter_condition); - let count = conn.execute(update.into()).await?; - - Ok(count as usize) + if let Some(selected_fields) = selected_fields { + Ok(update + .returning(selected_fields.as_columns(ctx).map(|c| c.set_is_selected(true))) + .into()) + } else { + Ok(update.into()) + } } // Generates a query like this: // UPDATE "public"."User" SET "name" = $1 WHERE "public"."User"."id" IN ($2,$3,$4,$5,$6,$7,$8,$9,$10,$11) AND "public"."User"."age" > $1 -pub(crate) async fn update_many_from_ids_and_filter( +pub(super) async fn update_many_from_ids_and_filter( conn: &dyn Queryable, model: &Model, record_filter: RecordFilter, args: WriteArgs, + selected_fields: Option<&ModelProjection>, ctx: &Context<'_>, -) -> crate::Result<(usize, Vec)> { +) -> crate::Result<(Vec>, Vec)> { let filter_condition = FilterBuilder::without_top_level_joins().visit_filter(record_filter.filter.clone(), ctx); let ids: Vec = conn.filter_selectors(model, record_filter, ctx).await?; if ids.is_empty() { - return Ok((0, Vec::new())); + return Ok((vec![], Vec::new())); } let updates = { - let update = build_update_and_set_query(model, args, None, ctx); + let update = build_update_and_set_query(model, args, selected_fields, ctx); let ids: Vec<&SelectionResult> = ids.iter().collect(); chunk_update_with_ids(update, model, &ids, filter_condition, ctx)? }; - let mut count = 0; - - for update in updates { - let update_count = conn.execute(update).await?; - - count += update_count; - } - - Ok((count as usize, ids)) + Ok((updates, ids)) } fn process_result_row( diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs index 137bff50ca58..ed200af58223 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs @@ -8,7 +8,7 @@ use crate::{ }; use connector_interface::*; use itertools::Itertools; -use quaint::ast::Insert; +use quaint::ast::{Insert, Query}; use quaint::{ error::ErrorKind, prelude::{native_uuid, uuid_to_bin, uuid_to_bin_swapped, Aliasable, Select, SqlFamily}, @@ -370,6 +370,25 @@ pub(crate) async fn update_record( } } +async fn generate_updates( + conn: &dyn Queryable, + model: &Model, + record_filter: RecordFilter, + args: WriteArgs, + selected_fields: Option<&ModelProjection>, + ctx: &Context<'_>, +) -> crate::Result>> { + if record_filter.has_selectors() { + let (updates, _) = + update_many_from_ids_and_filter(conn, model, record_filter, args, selected_fields, ctx).await?; + Ok(updates) + } else { + Ok(vec![ + update_many_from_filter(model, record_filter, args, selected_fields, ctx).await?, + ]) + } +} + /// Update multiple records in a database defined in `conn` and the records /// defined in `args`, and returning the number of updates /// This works via two ways, when there are ids in record_filter.selectors, it uses that to update @@ -385,15 +404,42 @@ pub(crate) async fn update_records( return Ok(0); } - if record_filter.has_selectors() { - let (count, _) = update_many_from_ids_and_filter(conn, model, record_filter, args, ctx).await?; + let mut count = 0; + for update in generate_updates(conn, model, record_filter, args, None, ctx).await? { + count += conn.execute(update).await?; + } + Ok(count as usize) +} - Ok(count) - } else { - let count = update_many_from_filter(conn, model, record_filter, args, ctx).await?; +/// Update records according to `WriteArgs`. Returns values of fields specified in +/// `selected_fields` for all updated rows. +pub(crate) async fn update_records_returning( + conn: &dyn Queryable, + model: &Model, + record_filter: RecordFilter, + args: WriteArgs, + selected_fields: FieldSelection, + ctx: &Context<'_>, +) -> crate::Result { + let field_names: Vec = selected_fields.db_names().collect(); + let idents = selected_fields.type_identifiers_with_arities(); + let meta = column_metadata::create(&field_names, &idents); + let mut records = ManyRecords::new(field_names.clone()); + + let updates = generate_updates(conn, model, record_filter, args, Some(&selected_fields.into()), ctx).await?; - Ok(count) + for update in updates { + let result_set = conn.query(update).await?; + + for result_row in result_set { + let sql_row = result_row.to_sql_row(&meta)?; + let record = Record::from(sql_row); + + records.push(record); + } } + + Ok(records) } /// Delete multiple records in `conn`, defined in the `Filter`. Result is the number of items deleted. diff --git a/query-engine/connectors/sql-query-connector/src/database/transaction.rs b/query-engine/connectors/sql-query-connector/src/database/transaction.rs index eea377ad5a57..9a3429ab6e2c 100644 --- a/query-engine/connectors/sql-query-connector/src/database/transaction.rs +++ b/query-engine/connectors/sql-query-connector/src/database/transaction.rs @@ -230,6 +230,29 @@ impl WriteOperations for SqlConnectorTransaction<'_> { .await } + async fn update_records_returning( + &mut self, + model: &Model, + record_filter: RecordFilter, + args: WriteArgs, + selected_fields: FieldSelection, + traceparent: Option, + ) -> connector::Result { + let ctx = Context::new(&self.connection_info, traceparent); + catch( + &self.connection_info, + write::update_records_returning( + self.inner.as_queryable(), + model, + record_filter, + args, + selected_fields, + &ctx, + ), + ) + .await + } + async fn update_record( &mut self, model: &Model, diff --git a/query-engine/core/src/interpreter/query_interpreters/write.rs b/query-engine/core/src/interpreter/query_interpreters/write.rs index 50096ed93392..065f5e6eebe3 100644 --- a/query-engine/core/src/interpreter/query_interpreters/write.rs +++ b/query-engine/core/src/interpreter/query_interpreters/write.rs @@ -305,11 +305,31 @@ async fn update_many( q: UpdateManyRecords, traceparent: Option, ) -> InterpretationResult { - let res = tx - .update_records(&q.model, q.record_filter, q.args, traceparent) - .await?; + if let Some(selected_fields) = q.selected_fields { + let records = tx + .update_records_returning(&q.model, q.record_filter, q.args, selected_fields.fields, traceparent) + .await?; - Ok(QueryResult::Count(res)) + let nested: Vec = + super::read::process_nested(tx, selected_fields.nested, Some(&records), traceparent).await?; + + let selection = RecordSelection { + name: q.name, + fields: selected_fields.order, + records, + nested, + model: q.model, + virtual_fields: vec![], + }; + + Ok(QueryResult::RecordSelection(Some(Box::new(selection)))) + } else { + let affected_records = tx + .update_records(&q.model, q.record_filter, q.args, traceparent) + .await?; + + Ok(QueryResult::Count(affected_records)) + } } async fn delete_many( diff --git a/query-engine/core/src/query_ast/write.rs b/query-engine/core/src/query_ast/write.rs index 940b36cc57cf..9d30e3ff7b5c 100644 --- a/query-engine/core/src/query_ast/write.rs +++ b/query-engine/core/src/query_ast/write.rs @@ -361,9 +361,20 @@ pub struct UpdateRecordWithoutSelection { #[derive(Debug, Clone)] pub struct UpdateManyRecords { + pub name: String, pub model: Model, pub record_filter: RecordFilter, pub args: WriteArgs, + /// Fields of updated records that client has requested to return. + /// `None` if the connector does not support returning the updated rows. + pub selected_fields: Option, +} + +#[derive(Debug, Clone)] +pub struct UpdateManyRecordsFields { + pub fields: FieldSelection, + pub order: Vec, + pub nested: Vec, } #[derive(Debug, Clone)] diff --git a/query-engine/core/src/query_graph_builder/builder.rs b/query-engine/core/src/query_graph_builder/builder.rs index 8ba2929caccd..1cce2f86edfd 100644 --- a/query-engine/core/src/query_graph_builder/builder.rs +++ b/query-engine/core/src/query_graph_builder/builder.rs @@ -116,7 +116,8 @@ impl<'a> QueryGraphBuilder<'a> { (QueryTag::CreateMany, Some(m)) => QueryGraph::root(|g| write::create_many_records(g, query_schema, m, false, parsed_field)), (QueryTag::CreateManyAndReturn, Some(m)) => QueryGraph::root(|g| write::create_many_records(g, query_schema, m, true, parsed_field)), (QueryTag::UpdateOne, Some(m)) => QueryGraph::root(|g| write::update_record(g, query_schema, m, parsed_field)), - (QueryTag::UpdateMany, Some(m)) => QueryGraph::root(|g| write::update_many_records(g, query_schema, m, parsed_field)), + (QueryTag::UpdateMany, Some(m)) => QueryGraph::root(|g| write::update_many_records(g, query_schema, m, false, parsed_field)), + (QueryTag::UpdateManyAndReturn, Some(m)) => QueryGraph::root(|g| write::update_many_records(g, query_schema, m, true, parsed_field)), (QueryTag::UpsertOne, Some(m)) => QueryGraph::root(|g| write::upsert_record(g, query_schema, m, parsed_field)), (QueryTag::DeleteOne, Some(m)) => QueryGraph::root(|g| write::delete_record(g, query_schema, m, parsed_field)), (QueryTag::DeleteMany, Some(m)) => QueryGraph::root(|g| write::delete_many_records(g, query_schema, m, parsed_field)), diff --git a/query-engine/core/src/query_graph_builder/write/nested/update_nested.rs b/query-engine/core/src/query_graph_builder/write/nested/update_nested.rs index 735bd0a88c31..4628912a4a46 100644 --- a/query-engine/core/src/query_graph_builder/write/nested/update_nested.rs +++ b/query-engine/core/src/query_graph_builder/write/nested/update_nested.rs @@ -142,8 +142,15 @@ pub fn nested_update_many( let find_child_records_node = utils::insert_find_children_by_parent_node(graph, parent, parent_relation_field, filter)?; - let update_many_node = - update::update_many_record_node(graph, query_schema, Filter::empty(), child_model.clone(), data_map)?; + let update_many_node = update::update_many_record_node( + graph, + query_schema, + Filter::empty(), + child_model.clone(), + None, + None, + data_map, + )?; graph.create_edge( &find_child_records_node, diff --git a/query-engine/core/src/query_graph_builder/write/update.rs b/query-engine/core/src/query_graph_builder/write/update.rs index f112fbaa5d54..2f9024d69320 100644 --- a/query-engine/core/src/query_graph_builder/write/update.rs +++ b/query-engine/core/src/query_graph_builder/write/update.rs @@ -1,5 +1,6 @@ use super::*; use crate::query_graph_builder::write::write_args_parser::*; +use crate::ParsedObject; use crate::{ query_ast::*, query_graph::{Node, NodeRef, QueryGraph, QueryGraphDependency}, @@ -124,6 +125,7 @@ pub fn update_many_records( graph: &mut QueryGraph, query_schema: &QuerySchema, model: Model, + with_field_selection: bool, mut field: ParsedField<'_>, ) -> QueryGraphBuilderResult<()> { graph.flag_transactional(); @@ -139,14 +141,30 @@ pub fn update_many_records( let data_map: ParsedInputMap<'_> = data_argument.value.try_into()?; if query_schema.relation_mode().uses_foreign_keys() { - update_many_record_node(graph, query_schema, filter, model, data_map)?; + update_many_record_node( + graph, + query_schema, + filter, + model, + Some(field.name), + field.nested_fields.filter(|_| with_field_selection), + data_map, + )?; } else { let pre_read_node = graph.create_node(utils::read_ids_infallible( model.clone(), model.primary_identifier(), filter, )); - let update_many_node = update_many_record_node(graph, query_schema, Filter::empty(), model.clone(), data_map)?; + let update_many_node = update_many_record_node( + graph, + query_schema, + Filter::empty(), + model.clone(), + Some(field.name), + field.nested_fields.filter(|_| with_field_selection), + data_map, + )?; utils::insert_emulated_on_update(graph, query_schema, &model, &pre_read_node, &update_many_node)?; @@ -249,6 +267,8 @@ pub fn update_many_record_node( query_schema: &QuerySchema, filter: T, model: Model, + name: Option, + nested_field_selection: Option>, data_map: ParsedInputMap<'_>, ) -> QueryGraphBuilderResult where @@ -263,10 +283,25 @@ where args.update_datetimes(&model); + let selected_fields = if let Some(nested_fields) = nested_field_selection { + let (selected_fields, selection_order, nested_read) = + super::read::utils::extract_selected_fields(nested_fields.fields, &model, query_schema)?; + + Some(UpdateManyRecordsFields { + fields: selected_fields, + order: selection_order, + nested: nested_read, + }) + } else { + None + }; + let update_many = UpdateManyRecords { + name: name.unwrap_or_default(), model, record_filter, args, + selected_fields, }; let update_many_node = graph.create_node(Query::Write(WriteQuery::UpdateManyRecords(update_many))); diff --git a/query-engine/core/src/query_graph_builder/write/utils.rs b/query-engine/core/src/query_graph_builder/write/utils.rs index a931fa1edc30..d9f89feade62 100644 --- a/query-engine/core/src/query_graph_builder/write/utils.rs +++ b/query-engine/core/src/query_graph_builder/write/utils.rs @@ -224,9 +224,11 @@ where let record_filter = filter.into(); let ur = UpdateManyRecords { + name: String::new(), model, record_filter, args, + selected_fields: None, }; graph.create_node(Query::Write(WriteQuery::UpdateManyRecords(ur))) @@ -610,9 +612,11 @@ pub fn emulate_on_delete_set_null( insert_find_children_by_parent_node(graph, node_providing_ids, &parent_relation_field, Filter::empty())?; let set_null_query = WriteQuery::UpdateManyRecords(UpdateManyRecords { + name: String::new(), model: dependent_model.clone(), record_filter: RecordFilter::empty(), args: WriteArgs::new(child_update_args, crate::executor::get_request_now()), + selected_fields: None, }); let set_null_dependents_node = graph.create_node(Query::Write(set_null_query)); @@ -760,9 +764,11 @@ pub fn emulate_on_update_set_null( insert_find_children_by_parent_node(graph, parent_node, &parent_relation_field, Filter::empty())?; let set_null_query = WriteQuery::UpdateManyRecords(UpdateManyRecords { + name: String::new(), model: dependent_model.clone(), record_filter: RecordFilter::empty(), args: WriteArgs::new(child_update_args, crate::executor::get_request_now()), + selected_fields: None, }); let set_null_dependents_node = graph.create_node(Query::Write(set_null_query)); @@ -1080,12 +1086,14 @@ pub fn emulate_on_update_cascade( insert_find_children_by_parent_node(graph, parent_node, &parent_relation_field, Filter::empty())?; let update_query = WriteQuery::UpdateManyRecords(UpdateManyRecords { + name: String::new(), model: dependent_model.clone(), record_filter: RecordFilter::empty(), args: WriteArgs::new( child_update_args.into_iter().collect(), crate::executor::get_request_now(), ), + selected_fields: None, }); let update_dependents_node = graph.create_node(Query::Write(update_query)); diff --git a/query-engine/schema/src/query_schema.rs b/query-engine/schema/src/query_schema.rs index f8fede0ea355..094b4bcdd600 100644 --- a/query-engine/schema/src/query_schema.rs +++ b/query-engine/schema/src/query_schema.rs @@ -247,6 +247,7 @@ pub enum QueryTag { CreateManyAndReturn, UpdateOne, UpdateMany, + UpdateManyAndReturn, DeleteOne, DeleteMany, UpsertOne, @@ -273,6 +274,7 @@ impl fmt::Display for QueryTag { Self::CreateManyAndReturn => "createManyAndReturn", Self::UpdateOne => "updateOne", Self::UpdateMany => "updateMany", + Self::UpdateManyAndReturn => "updateManyAndReturn", Self::DeleteOne => "deleteOne", Self::DeleteMany => "deleteMany", Self::UpsertOne => "upsertOne",