diff --git a/prisma-fmt/src/get_dmmf.rs b/prisma-fmt/src/get_dmmf.rs index 0b627a7e545..3733c36d455 100644 --- a/prisma-fmt/src/get_dmmf.rs +++ b/prisma-fmt/src/get_dmmf.rs @@ -5523,6 +5523,18 @@ mod tests { "isList": false } ] + }, + { + "name": "limit", + "isRequired": false, + "isNullable": false, + "inputTypes": [ + { + "type": "Int", + "location": "scalar", + "isList": false + } + ] } ], "isNullable": false, @@ -5567,6 +5579,18 @@ mod tests { "isList": false } ] + }, + { + "name": "limit", + "isRequired": false, + "isNullable": false, + "inputTypes": [ + { + "type": "Int", + "location": "scalar", + "isList": false + } + ] } ], "isNullable": false, @@ -5897,6 +5921,18 @@ mod tests { "isList": false } ] + }, + { + "name": "limit", + "isRequired": false, + "isNullable": false, + "inputTypes": [ + { + "type": "Int", + "location": "scalar", + "isList": false + } + ] } ], "isNullable": false, @@ -5941,6 +5977,18 @@ mod tests { "isList": false } ] + }, + { + "name": "limit", + "isRequired": false, + "isNullable": false, + "inputTypes": [ + { + "type": "Int", + "location": "scalar", + "isList": false + } + ] } ], "isNullable": false, diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/delete_many.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/delete_many.rs index e4b264e8bb3..50bbe195520 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/delete_many.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/delete_many.rs @@ -202,6 +202,24 @@ mod delete_many { Ok(()) } + // "The delete many Mutation" should "fail if limit param is negative" + #[connector_test] + async fn should_fail_with_negative_limit(runner: Runner) -> TestResult<()> { + create_row(&runner, r#"{ id: 1, title: "title1" }"#).await?; + create_row(&runner, r#"{ id: 2, title: "title2" }"#).await?; + create_row(&runner, r#"{ id: 3, title: "title3" }"#).await?; + create_row(&runner, r#"{ id: 4, title: "title4" }"#).await?; + + assert_error!( + &runner, + r#"mutation { deleteManyTodo(limit: -3){ count }}"#, + 2019, + "Provided limit (-3) must be a positive integer." + ); + + Ok(()) + } + fn nested_del_many() -> String { let schema = indoc! 
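The DMMF snapshot above now advertises an optional `limit: Int` argument on the many-record mutations, and the new `deleteMany` test asserts that a negative limit is rejected with error code 2019. A happy-path counterpart is sketched below; it is hypothetical (not part of this diff) but follows the same test-kit conventions (`#[connector_test]`, `create_row`, `run_query!`) and the `Todo` model used by the surrounding delete_many tests, so names and the expected snapshot are assumptions.

```rust
// Hypothetical sketch (not in this diff): deleteMany should remove at most `limit` rows.
#[connector_test]
async fn delete_max_limit_items(runner: Runner) -> TestResult<()> {
    create_row(&runner, r#"{ id: 1, title: "title1" }"#).await?;
    create_row(&runner, r#"{ id: 2, title: "title2" }"#).await?;
    create_row(&runner, r#"{ id: 3, title: "title3" }"#).await?;

    // With limit: 2, only two of the three matching rows are deleted.
    insta::assert_snapshot!(
        run_query!(&runner, r#"mutation { deleteManyTodo(limit: 2) { count } }"#),
        @r###"{"data":{"deleteManyTodo":{"count":2}}}"###
    );

    Ok(())
}
```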
{ r#"model ZChild{ diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs index c9a6d6df567..0b07790645c 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs @@ -122,6 +122,65 @@ mod update_many { Ok(()) } + // "An updateMany mutation" should "update max limit number of items" + #[connector_test] + async fn update_max_limit_items(runner: Runner) -> TestResult<()> { + create_row(&runner, r#"{ id: 1, optStr: "str1" }"#).await?; + create_row(&runner, r#"{ id: 2, optStr: "str2" }"#).await?; + create_row(&runner, r#"{ id: 3, optStr: "str3" }"#).await?; + + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { + updateManyTestModel( + where: { } + data: { optStr: { set: "updated" } } + limit: 2 + ){ + count + } + }"#), + @r###"{"data":{"updateManyTestModel":{"count":2}}}"### + ); + + insta::assert_snapshot!( + run_query!( + &runner, + r#"{ + findManyTestModel(orderBy: { id: asc }) { + optStr + } + }"#), + @r###"{"data":{"findManyTestModel":[{"optStr":"updated"},{"optStr":"updated"},{"optStr":"str3"}]}}"### + ); + + Ok(()) + } + + // "An updateMany mutation" should "fail if limit param is negative" + #[connector_test] + async fn should_fail_with_negative_limit(runner: Runner) -> TestResult<()> { + create_row(&runner, r#"{ id: 1, optStr: "str1" }"#).await?; + create_row(&runner, r#"{ id: 2, optStr: "str2" }"#).await?; + create_row(&runner, r#"{ id: 3, optStr: "str3" }"#).await?; + + assert_error!( + &runner, + r#"mutation { + updateManyTestModel( + where: { } + data: { optStr: { set: "updated" } } + limit: -2 + ){ + count + } + }"#, + 2019, + "Provided limit (-2) must be a positive integer." 
+ ); + + Ok(()) + } + // "An updateMany mutation" should "correctly apply all number operations for Int" #[connector_test(exclude(CockroachDb))] async fn apply_number_ops_for_int(runner: Runner) -> TestResult<()> { diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs b/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs index a0c49fca73a..94f29b21535 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs @@ -96,6 +96,7 @@ impl WriteOperations for MongoDbConnection { model: &Model, record_filter: connector_interface::RecordFilter, args: WriteArgs, + limit: Option, _traceparent: Option, ) -> connector_interface::Result { catch(async move { @@ -105,7 +106,7 @@ impl WriteOperations for MongoDbConnection { model, record_filter, args, - UpdateType::Many, + UpdateType::Many { limit }, ) .await?; @@ -120,6 +121,7 @@ impl WriteOperations for MongoDbConnection { _record_filter: connector_interface::RecordFilter, _args: WriteArgs, _selected_fields: FieldSelection, + _limit: Option, _traceparent: Option, ) -> connector_interface::Result { unimplemented!() @@ -162,7 +164,7 @@ impl WriteOperations for MongoDbConnection { &mut self, model: &Model, record_filter: connector_interface::RecordFilter, - limit: Option, + limit: Option, _traceparent: Option, ) -> connector_interface::Result { catch(write::delete_records( diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs index 00115b5cfba..31943e0dd6c 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs @@ -127,6 +127,7 @@ impl WriteOperations for MongoDbTransaction<'_> { model: &Model, record_filter: connector_interface::RecordFilter, args: connector_interface::WriteArgs, + limit: Option, _traceparent: Option, ) -> connector_interface::Result { catch(async move { @@ -136,7 +137,7 @@ impl WriteOperations for MongoDbTransaction<'_> { model, record_filter, args, - UpdateType::Many, + UpdateType::Many { limit }, ) .await?; Ok(result.len()) @@ -150,6 +151,7 @@ impl WriteOperations for MongoDbTransaction<'_> { _record_filter: connector_interface::RecordFilter, _args: connector_interface::WriteArgs, _selected_fields: FieldSelection, + _limit: Option, _traceparent: Option, ) -> connector_interface::Result { unimplemented!() @@ -191,7 +193,7 @@ impl WriteOperations for MongoDbTransaction<'_> { &mut self, model: &Model, record_filter: connector_interface::RecordFilter, - limit: Option, + limit: Option, _traceparent: Option, ) -> connector_interface::Result { catch(write::delete_records( diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs index 622583c2a5f..a92c373be86 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs @@ -1,4 +1,5 @@ use super::*; +use crate::error::MongoError::ConversionError; use crate::{ error::{DecorateErrorWithFieldInformationExtension, MongoError}, filter::{FilterPrefix, MongoFilter, MongoFilterVisitor}, @@ -160,6 +161,10 @@ pub async fn update_records<'conn>( let ids: Vec = if let Some(selectors) = 
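In the MongoDB connector the new limit is not a separate field on the write path; it rides along inside `UpdateType`, whose `Many` unit variant becomes a struct variant (the definition change appears further down in `query-connector/src/lib.rs`). A minimal, self-contained sketch of that shape and of the cap it implies, mirroring the `.take(match update_type { ... })` used in `root_queries/write.rs`:

```rust
// Sketch mirroring the changed enum in query-connector/src/lib.rs.
#[derive(PartialEq)]
pub enum UpdateType {
    Many { limit: Option<usize> },
    One,
}

/// Maximum number of records an update may touch; `None` on `Many` means "no cap".
fn max_records(update_type: &UpdateType) -> usize {
    match update_type {
        UpdateType::Many { limit } => limit.unwrap_or(usize::MAX),
        UpdateType::One => 1,
    }
}

fn main() {
    assert_eq!(max_records(&UpdateType::Many { limit: Some(2) }), 2);
    assert_eq!(max_records(&UpdateType::Many { limit: None }), usize::MAX);
    assert_eq!(max_records(&UpdateType::One), 1);
}
```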
record_filter.selectors { selectors .into_iter() + .take(match update_type { + UpdateType::Many { limit } => limit.unwrap_or(usize::MAX), + UpdateType::One => 1, + }) .map(|p| { (&id_field, p.values().next().unwrap()) .into_bson() @@ -205,7 +210,7 @@ pub async fn update_records<'conn>( // It's important we check the `matched_count` and not the `modified_count` here. // MongoDB returns `modified_count: 0` when performing a noop update, which breaks // nested connect mutations as it rely on the returned count to know whether the update happened. - if update_type == UpdateType::Many && res.matched_count == 0 { + if matches!(update_type, UpdateType::Many { limit: _ }) && res.matched_count == 0 { return Ok(Vec::new()); } } @@ -228,7 +233,7 @@ pub async fn delete_records<'conn>( session: &mut ClientSession, model: &Model, record_filter: RecordFilter, - limit: Option, + limit: Option, ) -> crate::Result { let coll = database.collection::(model.db_name()); let id_field = pick_singular_id(model); @@ -236,7 +241,7 @@ pub async fn delete_records<'conn>( let ids = if let Some(selectors) = record_filter.selectors { selectors .into_iter() - .take(limit.unwrap_or(i64::MAX) as usize) + .take(limit.unwrap_or(usize::MAX)) .map(|p| { (&id_field, p.values().next().unwrap()) .into_bson() @@ -305,7 +310,7 @@ async fn find_ids( session: &mut ClientSession, model: &Model, filter: MongoFilter, - limit: Option, + limit: Option, ) -> crate::Result> { let id_field = model.primary_identifier(); let mut builder = MongoReadQueryBuilder::new(model.clone()); @@ -321,7 +326,17 @@ async fn find_ids( let mut builder = builder.with_model_projection(id_field)?; - builder.limit = limit; + if let Some(limit) = limit { + builder.limit = match i64::try_from(limit) { + Ok(limit) => Some(limit), + Err(_) => { + return Err(ConversionError { + from: "usize".to_owned(), + to: "i64".to_owned(), + }) + } + } + } let query = builder.build()?; let docs = query.execute(collection, session).await?; diff --git a/query-engine/connectors/query-connector/src/interface.rs b/query-engine/connectors/query-connector/src/interface.rs index bfc6c5d725a..3bf1614e039 100644 --- a/query-engine/connectors/query-connector/src/interface.rs +++ b/query-engine/connectors/query-connector/src/interface.rs @@ -286,6 +286,7 @@ pub trait WriteOperations { model: &Model, record_filter: RecordFilter, args: WriteArgs, + limit: Option, traceparent: Option, ) -> crate::Result; @@ -299,6 +300,7 @@ pub trait WriteOperations { record_filter: RecordFilter, args: WriteArgs, selected_fields: FieldSelection, + limit: Option, traceparent: Option, ) -> crate::Result; @@ -326,7 +328,7 @@ pub trait WriteOperations { &mut self, model: &Model, record_filter: RecordFilter, - limit: Option, + limit: Option, traceparent: Option, ) -> crate::Result; diff --git a/query-engine/connectors/query-connector/src/lib.rs b/query-engine/connectors/query-connector/src/lib.rs index 5488dfaef49..c497f121ae9 100644 --- a/query-engine/connectors/query-connector/src/lib.rs +++ b/query-engine/connectors/query-connector/src/lib.rs @@ -20,6 +20,6 @@ pub type Result = std::result::Result; /// However when we updating any records we want to return an empty array if zero items were updated #[derive(PartialEq)] pub enum UpdateType { - Many, + Many { limit: Option }, One, } diff --git a/query-engine/connectors/sql-query-connector/src/database/connection.rs b/query-engine/connectors/sql-query-connector/src/database/connection.rs index e47fb9ce019..614f174e562 100644 --- 
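Because the MongoDB read-query builder expects an `i64` limit while the connector interface now passes `Option<usize>`, `find_ids` converts explicitly and surfaces a `ConversionError` instead of casting. A standalone sketch of the same guard, with a simplified error type standing in for `MongoError::ConversionError`:

```rust
// Simplified stand-in for the connector's MongoError::ConversionError.
#[derive(Debug)]
struct ConversionError {
    from: &'static str,
    to: &'static str,
}

fn to_query_limit(limit: Option<usize>) -> Result<Option<i64>, ConversionError> {
    match limit {
        None => Ok(None),
        Some(limit) => i64::try_from(limit)
            .map(Some)
            .map_err(|_| ConversionError { from: "usize", to: "i64" }),
    }
}

fn main() {
    assert_eq!(to_query_limit(Some(3)).unwrap(), Some(3));
    assert_eq!(to_query_limit(None).unwrap(), None);

    // Only values above i64::MAX can fail, which is possible on 64-bit targets.
    #[cfg(target_pointer_width = "64")]
    assert!(to_query_limit(Some(usize::MAX)).is_err());
}
```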
a/query-engine/connectors/sql-query-connector/src/database/connection.rs +++ b/query-engine/connectors/sql-query-connector/src/database/connection.rs @@ -226,12 +226,13 @@ where model: &Model, record_filter: RecordFilter, args: WriteArgs, + limit: Option, traceparent: Option, ) -> connector::Result { let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, - write::update_records(&self.inner, model, record_filter, args, &ctx), + write::update_records(&self.inner, model, record_filter, args, limit, &ctx), ) .await } @@ -242,12 +243,13 @@ where record_filter: RecordFilter, args: WriteArgs, selected_fields: FieldSelection, + limit: Option, traceparent: Option, ) -> connector::Result { let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, - write::update_records_returning(&self.inner, model, record_filter, args, selected_fields, &ctx), + write::update_records_returning(&self.inner, model, record_filter, args, selected_fields, limit, &ctx), ) .await } @@ -272,7 +274,7 @@ where &mut self, model: &Model, record_filter: RecordFilter, - limit: Option, + limit: Option, traceparent: Option, ) -> connector::Result { let ctx = Context::new(&self.connection_info, traceparent); diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/update.rs b/query-engine/connectors/sql-query-connector/src/database/operations/update.rs index 4773853ea25..9ea13127a4f 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/update.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/update.rs @@ -7,9 +7,10 @@ use crate::query_builder::write::{build_update_and_set_query, chunk_update_with_ use crate::row::ToSqlRow; use crate::{Context, QueryExt, Queryable}; +use crate::limit::wrap_with_limit_subquery_if_needed; use connector_interface::*; use itertools::Itertools; -use quaint::ast::Query; +use quaint::ast::*; use query_structure::*; /// Performs an update with an explicit selection set. @@ -79,7 +80,7 @@ pub(crate) async fn update_one_without_selection( let id_args = pick_args(&model.primary_identifier().into(), &args); // Perform the update and return the ids on which we've applied the update. // Note: We are _not_ getting back the ids from the update. Either we got some ids passed from the parent operation or we perform a read _before_ doing the update. 
- let (updates, ids) = update_many_from_ids_and_filter(conn, model, record_filter, args, None, ctx).await?; + let (updates, ids) = update_many_from_ids_and_filter(conn, model, record_filter, args, None, None, ctx).await?; for update in updates { conn.execute(update).await?; } @@ -103,10 +104,17 @@ pub(super) async fn update_many_from_filter( record_filter: RecordFilter, args: WriteArgs, selected_fields: Option<&ModelProjection>, + limit: Option, ctx: &Context<'_>, ) -> crate::Result> { let update = build_update_and_set_query(model, args, None, ctx); - let filter_condition = FilterBuilder::without_top_level_joins().visit_filter(record_filter.filter, ctx); + let filter_condition = wrap_with_limit_subquery_if_needed( + model, + FilterBuilder::without_top_level_joins().visit_filter(record_filter.filter, ctx), + limit, + ctx, + ); + let update = update.so_that(filter_condition); if let Some(selected_fields) = selected_fields { Ok(update @@ -125,6 +133,7 @@ pub(super) async fn update_many_from_ids_and_filter( record_filter: RecordFilter, args: WriteArgs, selected_fields: Option<&ModelProjection>, + limit: Option, ctx: &Context<'_>, ) -> crate::Result<(Vec>, Vec)> { let filter_condition = FilterBuilder::without_top_level_joins().visit_filter(record_filter.filter.clone(), ctx); @@ -136,7 +145,7 @@ pub(super) async fn update_many_from_ids_and_filter( let updates = { let update = build_update_and_set_query(model, args, selected_fields, ctx); - let ids: Vec<&SelectionResult> = ids.iter().collect(); + let ids: Vec<&SelectionResult> = ids.iter().take(limit.unwrap_or(usize::MAX)).collect(); chunk_update_with_ids(update, model, &ids, filter_condition, ctx)? }; diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs index 1afa8aea8b8..19a52fa94fb 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs @@ -376,15 +376,16 @@ async fn generate_updates( record_filter: RecordFilter, args: WriteArgs, selected_fields: Option<&ModelProjection>, + limit: Option, ctx: &Context<'_>, ) -> crate::Result>> { if record_filter.has_selectors() { let (updates, _) = - update_many_from_ids_and_filter(conn, model, record_filter, args, selected_fields, ctx).await?; + update_many_from_ids_and_filter(conn, model, record_filter, args, selected_fields, limit, ctx).await?; Ok(updates) } else { Ok(vec![ - update_many_from_filter(model, record_filter, args, selected_fields, ctx).await?, + update_many_from_filter(model, record_filter, args, selected_fields, limit, ctx).await?, ]) } } @@ -398,6 +399,7 @@ pub(crate) async fn update_records( model: &Model, record_filter: RecordFilter, args: WriteArgs, + limit: Option, ctx: &Context<'_>, ) -> crate::Result { if args.args.is_empty() { @@ -405,7 +407,7 @@ pub(crate) async fn update_records( } let mut count = 0; - for update in generate_updates(conn, model, record_filter, args, None, ctx).await? { + for update in generate_updates(conn, model, record_filter, args, None, limit, ctx).await? 
{ count += conn.execute(update).await?; } Ok(count as usize) @@ -419,6 +421,7 @@ pub(crate) async fn update_records_returning( record_filter: RecordFilter, args: WriteArgs, selected_fields: FieldSelection, + limit: Option, ctx: &Context<'_>, ) -> crate::Result { let field_names: Vec = selected_fields.db_names().collect(); @@ -426,7 +429,17 @@ pub(crate) async fn update_records_returning( let meta = column_metadata::create(&field_names, &idents); let mut records = ManyRecords::new(field_names.clone()); - for update in generate_updates(conn, model, record_filter, args, Some(&selected_fields.into()), ctx).await? { + for update in generate_updates( + conn, + model, + record_filter, + args, + Some(&selected_fields.into()), + limit, + ctx, + ) + .await? + { let result_set = conn.query(update).await?; for result_row in result_set { @@ -445,7 +458,7 @@ pub(crate) async fn delete_records( conn: &dyn Queryable, model: &Model, record_filter: RecordFilter, - limit: Option, + limit: Option, ctx: &Context<'_>, ) -> crate::Result { let filter_condition = FilterBuilder::without_top_level_joins().visit_filter(record_filter.clone().filter, ctx); @@ -461,8 +474,9 @@ pub(crate) async fn delete_records( { row_count += conn.execute(delete).await?; if let Some(old_remaining_limit) = remaining_limit { - let new_remaining_limit = old_remaining_limit - row_count as i64; - if new_remaining_limit <= 0 { + // u64 to usize cast here cannot 'overflow' as the number of rows was limited to MAX usize in the first place. + let new_remaining_limit = old_remaining_limit - row_count as usize; + if new_remaining_limit == 0 { break; } remaining_limit = Some(new_remaining_limit); diff --git a/query-engine/connectors/sql-query-connector/src/database/transaction.rs b/query-engine/connectors/sql-query-connector/src/database/transaction.rs index cdf5a466237..6528343f54f 100644 --- a/query-engine/connectors/sql-query-connector/src/database/transaction.rs +++ b/query-engine/connectors/sql-query-connector/src/database/transaction.rs @@ -220,12 +220,13 @@ impl WriteOperations for SqlConnectorTransaction<'_> { model: &Model, record_filter: RecordFilter, args: WriteArgs, + limit: Option, traceparent: Option, ) -> connector::Result { let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, - write::update_records(self.inner.as_queryable(), model, record_filter, args, &ctx), + write::update_records(self.inner.as_queryable(), model, record_filter, args, limit, &ctx), ) .await } @@ -236,6 +237,7 @@ impl WriteOperations for SqlConnectorTransaction<'_> { record_filter: RecordFilter, args: WriteArgs, selected_fields: FieldSelection, + limit: Option, traceparent: Option, ) -> connector::Result { let ctx = Context::new(&self.connection_info, traceparent); @@ -247,6 +249,7 @@ impl WriteOperations for SqlConnectorTransaction<'_> { record_filter, args, selected_fields, + limit, &ctx, ), ) @@ -280,7 +283,7 @@ impl WriteOperations for SqlConnectorTransaction<'_> { &mut self, model: &Model, record_filter: RecordFilter, - limit: Option, + limit: Option, traceparent: Option, ) -> connector::Result { catch(&self.connection_info, async { diff --git a/query-engine/connectors/sql-query-connector/src/lib.rs b/query-engine/connectors/sql-query-connector/src/lib.rs index 9bd6c2d7f21..2f019b41e54 100644 --- a/query-engine/connectors/sql-query-connector/src/lib.rs +++ b/query-engine/connectors/sql-query-connector/src/lib.rs @@ -8,6 +8,7 @@ mod database; mod error; mod filter; mod join_utils; +mod limit; mod model_extensions; mod 
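When the delete runs over pre-resolved ids it is executed in chunks, so the limit has to be tracked across chunk executions and the loop stops once the budget is exhausted. Below is a simplified standalone sketch of that bookkeeping, not the connector's exact code: the real loop works on quaint queries and `u64` row counts, while this version decrements per chunk with `saturating_sub` for clarity.

```rust
/// Returns the total number of rows "deleted", stopping once `limit` is reached.
/// `chunks` stands in for the row counts returned by executing each chunked DELETE.
fn delete_in_chunks(chunks: &[u64], limit: Option<usize>) -> usize {
    let mut remaining = limit;
    let mut row_count: u64 = 0;

    for affected in chunks {
        row_count += affected;

        if let Some(old) = remaining {
            let new = old.saturating_sub(*affected as usize);
            if new == 0 {
                break; // budget exhausted, skip the remaining chunks
            }
            remaining = Some(new);
        }
    }

    row_count as usize
}

fn main() {
    // Three chunks of 2 rows each, capped at 4: the third chunk never runs.
    assert_eq!(delete_in_chunks(&[2, 2, 2], Some(4)), 4);
    // No limit: everything runs.
    assert_eq!(delete_in_chunks(&[2, 2, 2], None), 6);
}
```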
nested_aggregations; mod ordering; diff --git a/query-engine/connectors/sql-query-connector/src/limit.rs b/query-engine/connectors/sql-query-connector/src/limit.rs new file mode 100644 index 00000000000..8df1e749e4b --- /dev/null +++ b/query-engine/connectors/sql-query-connector/src/limit.rs @@ -0,0 +1,31 @@ +use crate::{model_extensions::*, Context}; +use quaint::ast::*; +use query_structure::*; + +pub(crate) fn wrap_with_limit_subquery_if_needed<'a>( + model: &Model, + filter_condition: ConditionTree<'a>, + limit: Option, + ctx: &Context, +) -> ConditionTree<'a> { + if let Some(limit) = limit { + let columns = model + .primary_identifier() + .as_scalar_fields() + .expect("primary identifier must contain scalar fields") + .into_iter() + .map(|f| f.as_column(ctx)) + .collect::>(); + + ConditionTree::from( + Row::from(columns.clone()).in_selection( + Select::from_table(model.as_table(ctx)) + .columns(columns) + .so_that(filter_condition) + .limit(limit), + ), + ) + } else { + filter_condition + } +} diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/write.rs b/query-engine/connectors/sql-query-connector/src/query_builder/write.rs index 0ec506e1b6f..2f4ab525e84 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/write.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/write.rs @@ -1,3 +1,4 @@ +use crate::limit::wrap_with_limit_subquery_if_needed; use crate::{model_extensions::*, sql_trace::SqlTraceComment, Context}; use connector_interface::{DatasourceFieldName, ScalarWriteOperation, WriteArgs}; use quaint::ast::*; @@ -226,32 +227,13 @@ pub(crate) fn delete_returning( pub(crate) fn delete_many_from_filter( model: &Model, filter_condition: ConditionTree<'static>, - limit: Option, + limit: Option, ctx: &Context<'_>, ) -> Query<'static> { - let condition = if let Some(limit) = limit { - let columns = model - .primary_identifier() - .as_scalar_fields() - .expect("primary identifier must contain scalar fields") - .into_iter() - .map(|f| f.as_column(ctx)) - .collect::>(); - - ConditionTree::from( - Row::from(columns.clone()).in_selection( - Select::from_table(model.as_table(ctx)) - .columns(columns) - .so_that(filter_condition) - .limit(limit as usize), - ), - ) - } else { - filter_condition - }; + let filter_condition = wrap_with_limit_subquery_if_needed(model, filter_condition, limit, ctx); Delete::from_table(model.as_table(ctx)) - .so_that(condition) + .so_that(filter_condition) .add_traceparent(ctx.traceparent) .into() } @@ -260,7 +242,7 @@ pub(crate) fn delete_many_from_ids_and_filter( model: &Model, ids: &[&SelectionResult], filter_condition: ConditionTree<'static>, - limit: Option, + limit: Option, ctx: &Context<'_>, ) -> Vec> { let columns: Vec<_> = ModelProjection::from(model.primary_identifier()) diff --git a/query-engine/core/src/interpreter/query_interpreters/write.rs b/query-engine/core/src/interpreter/query_interpreters/write.rs index 083cf238cae..3dcb992b435 100644 --- a/query-engine/core/src/interpreter/query_interpreters/write.rs +++ b/query-engine/core/src/interpreter/query_interpreters/write.rs @@ -307,7 +307,14 @@ async fn update_many( ) -> InterpretationResult { if let Some(selected_fields) = q.selected_fields { let records = tx - .update_records_returning(&q.model, q.record_filter, q.args, selected_fields.fields, traceparent) + .update_records_returning( + &q.model, + q.record_filter, + q.args, + selected_fields.fields, + q.limit, + traceparent, + ) .await?; let nested: Vec = @@ -325,7 +332,7 @@ async fn 
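The new `wrap_with_limit_subquery_if_needed` helper (shared by `delete_many_from_filter` and the updateMany path) avoids dialect-specific `UPDATE ... LIMIT` / `DELETE ... LIMIT` by pushing the limit into a subquery over the primary key. The real implementation builds a quaint `ConditionTree`; the string-based sketch below only illustrates the shape of the condition it produces, and the table/column names ("Todo", "id") are assumptions for the example.

```rust
// Illustrative only: the real helper returns a quaint ConditionTree, not SQL text.
fn wrap_with_limit(table: &str, pk: &str, filter_sql: &str, limit: Option<usize>) -> String {
    match limit {
        Some(n) => format!("({pk}) IN (SELECT {pk} FROM {table} WHERE {filter_sql} LIMIT {n})"),
        None => filter_sql.to_string(),
    }
}

fn main() {
    assert_eq!(
        wrap_with_limit("Todo", "id", "title = 'title1'", Some(2)),
        "(id) IN (SELECT id FROM Todo WHERE title = 'title1' LIMIT 2)"
    );
    // Without a limit the original filter is used unchanged.
    assert_eq!(wrap_with_limit("Todo", "id", "title = 'title1'", None), "title = 'title1'");
}
```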
update_many( Ok(QueryResult::RecordSelection(Some(Box::new(selection)))) } else { let affected_records = tx - .update_records(&q.model, q.record_filter, q.args, traceparent) + .update_records(&q.model, q.record_filter, q.args, q.limit, traceparent) .await?; Ok(QueryResult::Count(affected_records)) diff --git a/query-engine/core/src/query_ast/write.rs b/query-engine/core/src/query_ast/write.rs index 51ddf05724f..ca0287179e3 100644 --- a/query-engine/core/src/query_ast/write.rs +++ b/query-engine/core/src/query_ast/write.rs @@ -368,6 +368,7 @@ pub struct UpdateManyRecords { /// Fields of updated records that client has requested to return. /// `None` if the connector does not support returning the updated rows. pub selected_fields: Option, + pub limit: Option, } #[derive(Debug, Clone)] @@ -397,7 +398,7 @@ pub struct DeleteRecordFields { pub struct DeleteManyRecords { pub model: Model, pub record_filter: RecordFilter, - pub limit: Option, + pub limit: Option, } #[derive(Debug, Clone)] diff --git a/query-engine/core/src/query_graph_builder/write/delete.rs b/query-engine/core/src/query_graph_builder/write/delete.rs index b0c43d3d5ad..9f207144051 100644 --- a/query-engine/core/src/query_graph_builder/write/delete.rs +++ b/query-engine/core/src/query_graph_builder/write/delete.rs @@ -1,12 +1,12 @@ use super::*; -use crate::query_document::ParsedInputValue; +use crate::query_graph_builder::write::limit::validate_limit; use crate::{ query_ast::*, query_graph::{Node, QueryGraph, QueryGraphDependency}, ArgumentListLookup, FilteredQuery, ParsedField, }; use psl::datamodel_connector::ConnectorCapability; -use query_structure::{Filter, Model, PrismaValue}; +use query_structure::{Filter, Model}; use schema::{constants::args, QuerySchema}; use std::convert::TryInto; @@ -111,13 +111,8 @@ pub fn delete_many_records( Some(where_arg) => extract_filter(where_arg.value.try_into()?, &model)?, None => Filter::empty(), }; - let limit = field - .arguments - .lookup(args::LIMIT) - .and_then(|limit_arg| match limit_arg.value { - ParsedInputValue::Single(PrismaValue::Int(i)) => Some(i), - _ => None, - }); + + let limit = validate_limit(field.arguments.lookup(args::LIMIT))?; let model_id = model.primary_identifier(); let record_filter = filter.clone().into(); diff --git a/query-engine/core/src/query_graph_builder/write/limit.rs b/query-engine/core/src/query_graph_builder/write/limit.rs new file mode 100644 index 00000000000..3ce45dddf71 --- /dev/null +++ b/query-engine/core/src/query_graph_builder/write/limit.rs @@ -0,0 +1,31 @@ +use crate::query_document::{ParsedArgument, ParsedInputValue}; +use crate::query_graph_builder::{QueryGraphBuilderError, QueryGraphBuilderResult}; +use query_structure::PrismaValue; + +pub(crate) fn validate_limit(limit_arg: Option>) -> QueryGraphBuilderResult> { + let limit = limit_arg.and_then(|limit_arg| match limit_arg.value { + ParsedInputValue::Single(PrismaValue::Int(i)) => Some(i), + _ => None, + }); + + match limit { + Some(i) => { + if i < 0 { + return Err(QueryGraphBuilderError::InputError(format!( + "Provided limit ({}) must be a positive integer.", + i + ))); + } + + match usize::try_from(i) { + Ok(i) => Ok(Some(i)), + Err(_) => Err(QueryGraphBuilderError::InputError(format!( + "Provided limit ({}) is beyond max int value for platform ({}).", + i, + usize::MAX + ))), + } + } + None => Ok(None), + } +} diff --git a/query-engine/core/src/query_graph_builder/write/mod.rs b/query-engine/core/src/query_graph_builder/write/mod.rs index 8db664e91d6..d1b1e62d0c5 100644 --- 
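`validate_limit` is what turns the user-provided `Int` into the `Option<usize>` the connectors expect, rejecting negative values with the exact message asserted in the new tests. A simplified version operating directly on the parsed integer (the real function first pattern-matches a `ParsedArgument` and returns a `QueryGraphBuilderError`):

```rust
// Simplified: the real validate_limit extracts the i64 from a ParsedArgument first.
fn validate_limit(limit: Option<i64>) -> Result<Option<usize>, String> {
    match limit {
        None => Ok(None),
        Some(i) if i < 0 => Err(format!("Provided limit ({i}) must be a positive integer.")),
        Some(i) => usize::try_from(i).map(Some).map_err(|_| {
            format!("Provided limit ({i}) is beyond max int value for platform ({}).", usize::MAX)
        }),
    }
}

fn main() {
    assert_eq!(validate_limit(Some(3)), Ok(Some(3)));
    assert_eq!(validate_limit(None), Ok(None));
    assert_eq!(
        validate_limit(Some(-3)),
        Err("Provided limit (-3) must be a positive integer.".to_string())
    );
}
```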
a/query-engine/core/src/query_graph_builder/write/mod.rs +++ b/query-engine/core/src/query_graph_builder/write/mod.rs @@ -2,6 +2,7 @@ mod connect; mod create; mod delete; mod disconnect; +mod limit; mod nested; mod raw; mod update; diff --git a/query-engine/core/src/query_graph_builder/write/nested/update_nested.rs b/query-engine/core/src/query_graph_builder/write/nested/update_nested.rs index 4628912a4a4..353cb36f6da 100644 --- a/query-engine/core/src/query_graph_builder/write/nested/update_nested.rs +++ b/query-engine/core/src/query_graph_builder/write/nested/update_nested.rs @@ -1,4 +1,5 @@ use super::*; +use crate::query_graph_builder::write::update::UpdateManyRecordNodeOptionals; use crate::{ query_ast::*, query_graph::{Node, NodeRef, QueryGraph, QueryGraphDependency}, @@ -147,9 +148,12 @@ pub fn nested_update_many( query_schema, Filter::empty(), child_model.clone(), - None, - None, data_map, + UpdateManyRecordNodeOptionals { + name: None, + nested_field_selection: None, + limit: None, + }, )?; graph.create_edge( diff --git a/query-engine/core/src/query_graph_builder/write/update.rs b/query-engine/core/src/query_graph_builder/write/update.rs index 2f9024d6932..0c262b5b459 100644 --- a/query-engine/core/src/query_graph_builder/write/update.rs +++ b/query-engine/core/src/query_graph_builder/write/update.rs @@ -1,4 +1,5 @@ use super::*; +use crate::query_graph_builder::write::limit::validate_limit; use crate::query_graph_builder::write::write_args_parser::*; use crate::ParsedObject; use crate::{ @@ -136,6 +137,9 @@ pub fn update_many_records( None => Filter::empty(), }; + // "limit" + let limit = validate_limit(field.arguments.lookup(args::LIMIT))?; + // "data" let data_argument = field.arguments.lookup(args::DATA).unwrap(); let data_map: ParsedInputMap<'_> = data_argument.value.try_into()?; @@ -146,9 +150,12 @@ pub fn update_many_records( query_schema, filter, model, - Some(field.name), - field.nested_fields.filter(|_| with_field_selection), data_map, + UpdateManyRecordNodeOptionals { + name: Some(field.name), + nested_field_selection: field.nested_fields.filter(|_| with_field_selection), + limit, + }, )?; } else { let pre_read_node = graph.create_node(utils::read_ids_infallible( @@ -161,9 +168,12 @@ pub fn update_many_records( query_schema, Filter::empty(), model.clone(), - Some(field.name), - field.nested_fields.filter(|_| with_field_selection), data_map, + UpdateManyRecordNodeOptionals { + name: Some(field.name), + nested_field_selection: field.nested_fields.filter(|_| with_field_selection), + limit, + }, )?; utils::insert_emulated_on_update(graph, query_schema, &model, &pre_read_node, &update_many_node)?; @@ -267,9 +277,8 @@ pub fn update_many_record_node( query_schema: &QuerySchema, filter: T, model: Model, - name: Option, - nested_field_selection: Option>, data_map: ParsedInputMap<'_>, + additional_args: UpdateManyRecordNodeOptionals<'_>, ) -> QueryGraphBuilderResult where T: Into, @@ -283,7 +292,7 @@ where args.update_datetimes(&model); - let selected_fields = if let Some(nested_fields) = nested_field_selection { + let selected_fields = if let Some(nested_fields) = additional_args.nested_field_selection { let (selected_fields, selection_order, nested_read) = super::read::utils::extract_selected_fields(nested_fields.fields, &model, query_schema)?; @@ -297,11 +306,12 @@ where }; let update_many = UpdateManyRecords { - name: name.unwrap_or_default(), + name: additional_args.name.unwrap_or_default(), model, record_filter, args, selected_fields, + limit: additional_args.limit, }; let 
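Rather than adding yet another positional `Option` to `update_many_record_node`, the optional pieces (name, nested field selection, limit) are grouped into `UpdateManyRecordNodeOptionals`, whose definition follows in `update.rs`; call sites such as `nested_update_many` then spell out exactly what they omit. A small stand-in sketch of the pattern, with simplified field types in place of the parsed-query types used by the real struct:

```rust
// Sketch of the options-struct pattern; field types are simplified stand-ins.
#[derive(Default, Debug)]
struct UpdateManyRecordNodeOptionals {
    name: Option<String>,
    nested_field_selection: Option<String>,
    limit: Option<usize>,
}

fn main() {
    // Nested updateMany passes no name, no nested selection and no limit ...
    let nested = UpdateManyRecordNodeOptionals::default();

    // ... while top-level updateMany forwards all three explicitly.
    let top_level = UpdateManyRecordNodeOptionals {
        name: Some("updateManyTestModel".into()),
        nested_field_selection: None,
        limit: Some(2),
    };

    println!("nested: {nested:?}");
    println!("top level: {top_level:?}");
}
```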
update_many_node = graph.create_node(Query::Write(WriteQuery::UpdateManyRecords(update_many))); @@ -342,3 +352,9 @@ fn can_use_atomic_update( true } + +pub struct UpdateManyRecordNodeOptionals<'a> { + pub name: Option, + pub nested_field_selection: Option>, + pub limit: Option, +} diff --git a/query-engine/core/src/query_graph_builder/write/utils.rs b/query-engine/core/src/query_graph_builder/write/utils.rs index d73f5ae4a0b..b5db88b2240 100644 --- a/query-engine/core/src/query_graph_builder/write/utils.rs +++ b/query-engine/core/src/query_graph_builder/write/utils.rs @@ -229,6 +229,7 @@ where record_filter, args, selected_fields: None, + limit: None, }; graph.create_node(Query::Write(WriteQuery::UpdateManyRecords(ur))) @@ -618,6 +619,7 @@ pub fn emulate_on_delete_set_null( record_filter: RecordFilter::empty(), args: WriteArgs::new(child_update_args, crate::executor::get_request_now()), selected_fields: None, + limit: None, }); let set_null_dependents_node = graph.create_node(Query::Write(set_null_query)); @@ -770,6 +772,7 @@ pub fn emulate_on_update_set_null( record_filter: RecordFilter::empty(), args: WriteArgs::new(child_update_args, crate::executor::get_request_now()), selected_fields: None, + limit: None, }); let set_null_dependents_node = graph.create_node(Query::Write(set_null_query)); @@ -1095,6 +1098,7 @@ pub fn emulate_on_update_cascade( crate::executor::get_request_now(), ), selected_fields: None, + limit: None, }); let update_dependents_node = graph.create_node(Query::Write(update_query)); diff --git a/query-engine/dmmf/src/tests/test-schemas/snapshots/odoo.snapshot.json.gz b/query-engine/dmmf/src/tests/test-schemas/snapshots/odoo.snapshot.json.gz index 5be449caf27..a122f9693ab 100644 Binary files a/query-engine/dmmf/src/tests/test-schemas/snapshots/odoo.snapshot.json.gz and b/query-engine/dmmf/src/tests/test-schemas/snapshots/odoo.snapshot.json.gz differ diff --git a/query-engine/schema/src/build/input_types/fields/arguments.rs b/query-engine/schema/src/build/input_types/fields/arguments.rs index 7e6f51b05a6..4b42f91fa6b 100644 --- a/query-engine/schema/src/build/input_types/fields/arguments.rs +++ b/query-engine/schema/src/build/input_types/fields/arguments.rs @@ -73,15 +73,20 @@ pub(crate) fn upsert_arguments(ctx: &QuerySchema, model: Model) -> Vec Vec> { let update_many_types = update_many_objects::update_many_input_types(ctx, model.clone(), None); let where_arg = where_argument(ctx, &model); + let limit_arg = input_field(args::LIMIT, vec![InputType::int()], None).optional(); - vec![input_field(args::DATA.to_owned(), update_many_types, None), where_arg] + vec![ + input_field(args::DATA.to_owned(), update_many_types, None), + where_arg, + limit_arg, + ] } -/// Builds "where" argument intended for the delete many field. +/// Builds "where" and "limit" argument intended for the delete many field. pub(crate) fn delete_many_arguments(ctx: &QuerySchema, model: Model) -> Vec> { let where_arg = where_argument(ctx, &model);