
Commit

refactor: validate provided limit value and prevent overflows
FGoessler committed Jan 10, 2025
1 parent 63e697e commit d8c2f27
Showing 16 changed files with 93 additions and 54 deletions.
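In short: user-supplied limits for updateMany/deleteMany now travel through the engine as Option<usize>, validated once at the query-graph layer, and any narrowing back to a driver integer type (such as the i64 the MongoDB query builder expects) is a checked conversion rather than a silent "as" cast. A condensed sketch of that pattern, using hypothetical helper names that are not part of this diff:

    // Illustrative only: hypothetical helpers mirroring the pattern this commit applies.
    fn parse_limit(raw: i64) -> Result<usize, String> {
        // negative or non-representable limits are rejected up front
        usize::try_from(raw).map_err(|_| format!("invalid limit: {raw}"))
    }

    fn to_driver_limit(limit: usize) -> Result<i64, String> {
        // a driver that needs i64 gets a checked narrowing instead of a wrapping "as" cast
        i64::try_from(limit).map_err(|_| format!("limit {limit} does not fit into i64"))
    }

    fn main() {
        assert_eq!(parse_limit(10), Ok(10));
        assert!(parse_limit(-1).is_err());
        assert_eq!(to_driver_limit(10), Ok(10));
    }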
@@ -96,7 +96,7 @@ impl WriteOperations for MongoDbConnection {
model: &Model,
record_filter: connector_interface::RecordFilter,
args: WriteArgs,
limit: Option<i64>,
limit: Option<usize>,
_traceparent: Option<TraceParent>,
) -> connector_interface::Result<usize> {
catch(async move {
@@ -121,7 +121,7 @@ impl WriteOperations for MongoDbConnection {
_record_filter: connector_interface::RecordFilter,
_args: WriteArgs,
_selected_fields: FieldSelection,
_limit: Option<i64>,
_limit: Option<usize>,
_traceparent: Option<TraceParent>,
) -> connector_interface::Result<ManyRecords> {
unimplemented!()
@@ -164,7 +164,7 @@ impl WriteOperations for MongoDbConnection {
&mut self,
model: &Model,
record_filter: connector_interface::RecordFilter,
limit: Option<i64>,
limit: Option<usize>,
_traceparent: Option<TraceParent>,
) -> connector_interface::Result<usize> {
catch(write::delete_records(
@@ -127,7 +127,7 @@ impl WriteOperations for MongoDbTransaction<'_> {
model: &Model,
record_filter: connector_interface::RecordFilter,
args: connector_interface::WriteArgs,
limit: Option<i64>,
limit: Option<usize>,
_traceparent: Option<TraceParent>,
) -> connector_interface::Result<usize> {
catch(async move {
@@ -151,7 +151,7 @@ impl WriteOperations for MongoDbTransaction<'_> {
_record_filter: connector_interface::RecordFilter,
_args: connector_interface::WriteArgs,
_selected_fields: FieldSelection,
_limit: Option<i64>,
_limit: Option<usize>,
_traceparent: Option<TraceParent>,
) -> connector_interface::Result<ManyRecords> {
unimplemented!()
@@ -193,7 +193,7 @@ impl WriteOperations for MongoDbTransaction<'_> {
&mut self,
model: &Model,
record_filter: connector_interface::RecordFilter,
limit: Option<i64>,
limit: Option<usize>,
_traceparent: Option<TraceParent>,
) -> connector_interface::Result<usize> {
catch(write::delete_records(
@@ -1,4 +1,5 @@
use super::*;
use crate::error::MongoError::ConversionError;
use crate::{
error::{DecorateErrorWithFieldInformationExtension, MongoError},
filter::{FilterPrefix, MongoFilter, MongoFilterVisitor},
@@ -161,7 +162,7 @@ pub async fn update_records<'conn>(
selectors
.into_iter()
.take(match update_type {
UpdateType::Many { limit } => limit.unwrap_or(i64::MAX),
UpdateType::Many { limit } => limit.unwrap_or(usize::MAX),
UpdateType::One => 1,
} as usize)
.map(|p| {
@@ -232,15 +233,15 @@ pub async fn delete_records<'conn>(
session: &mut ClientSession,
model: &Model,
record_filter: RecordFilter,
limit: Option<i64>,
limit: Option<usize>,
) -> crate::Result<usize> {
let coll = database.collection::<Document>(model.db_name());
let id_field = pick_singular_id(model);

let ids = if let Some(selectors) = record_filter.selectors {
selectors
.into_iter()
.take(limit.unwrap_or(i64::MAX) as usize)
.take(limit.unwrap_or(usize::MAX))
.map(|p| {
(&id_field, p.values().next().unwrap())
.into_bson()
@@ -309,7 +310,7 @@ async fn find_ids(
session: &mut ClientSession,
model: &Model,
filter: MongoFilter,
limit: Option<i64>,
limit: Option<usize>,
) -> crate::Result<Vec<Bson>> {
let id_field = model.primary_identifier();
let mut builder = MongoReadQueryBuilder::new(model.clone());
@@ -325,7 +326,17 @@

let mut builder = builder.with_model_projection(id_field)?;

builder.limit = limit;
if let Some(limit) = limit {
builder.limit = match i64::try_from(limit) {
Ok(limit) => Some(limit),
Err(_) => {
return Err(ConversionError {
from: "usize".to_owned(),
to: "i64".to_owned(),
})
}
}
}

let query = builder.build()?;
let docs = query.execute(collection, session).await?;
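For context (not part of the diff): on 64-bit targets usize::MAX (2^64 - 1) exceeds i64::MAX (2^63 - 1), so the checked i64::try_from above is what turns the formerly silent cast into a ConversionError when a limit cannot be represented as the i64 the MongoDB read query builder expects. A minimal illustration of that boundary:

    fn main() {
        // Small limits convert losslessly.
        assert_eq!(i64::try_from(10_usize).unwrap(), 10i64);
        // On 64-bit platforms usize::MAX does not fit into i64, so try_from errors
        // instead of wrapping the way a plain "as i64" cast would.
        assert!(i64::try_from(usize::MAX).is_err());
    }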
6 changes: 3 additions & 3 deletions query-engine/connectors/query-connector/src/interface.rs
@@ -286,7 +286,7 @@ pub trait WriteOperations {
model: &Model,
record_filter: RecordFilter,
args: WriteArgs,
limit: Option<i64>,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> crate::Result<usize>;

@@ -300,7 +300,7 @@ pub trait WriteOperations {
record_filter: RecordFilter,
args: WriteArgs,
selected_fields: FieldSelection,
limit: Option<i64>,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> crate::Result<ManyRecords>;

@@ -328,7 +328,7 @@ pub trait WriteOperations {
&mut self,
model: &Model,
record_filter: RecordFilter,
limit: Option<i64>,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> crate::Result<usize>;

2 changes: 1 addition & 1 deletion query-engine/connectors/query-connector/src/lib.rs
@@ -20,6 +20,6 @@ pub type Result<T> = std::result::Result<T, error::ConnectorError>;
However, when we update any records we want to return an empty array if zero items were updated
#[derive(PartialEq)]
pub enum UpdateType {
Many { limit: Option<i64> },
Many { limit: Option<usize> },
One,
}
@@ -226,7 +226,7 @@ where
model: &Model,
record_filter: RecordFilter,
args: WriteArgs,
limit: Option<i64>,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> connector::Result<usize> {
let ctx = Context::new(&self.connection_info, traceparent);
@@ -243,7 +243,7 @@
record_filter: RecordFilter,
args: WriteArgs,
selected_fields: FieldSelection,
limit: Option<i64>,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> connector::Result<ManyRecords> {
let ctx = Context::new(&self.connection_info, traceparent);
@@ -274,7 +274,7 @@ where
&mut self,
model: &Model,
record_filter: RecordFilter,
limit: Option<i64>,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> connector::Result<usize> {
let ctx = Context::new(&self.connection_info, traceparent);
@@ -104,7 +104,7 @@ pub(super) async fn update_many_from_filter(
record_filter: RecordFilter,
args: WriteArgs,
selected_fields: Option<&ModelProjection>,
limit: Option<i64>,
limit: Option<usize>,
ctx: &Context<'_>,
) -> crate::Result<Query<'static>> {
let update = build_update_and_set_query(model, args, None, ctx);
@@ -133,7 +133,7 @@ pub(super) async fn update_many_from_ids_and_filter(
record_filter: RecordFilter,
args: WriteArgs,
selected_fields: Option<&ModelProjection>,
limit: Option<i64>,
limit: Option<usize>,
ctx: &Context<'_>,
) -> crate::Result<(Vec<Query<'static>>, Vec<SelectionResult>)> {
let filter_condition = FilterBuilder::without_top_level_joins().visit_filter(record_filter.filter.clone(), ctx);
@@ -145,7 +145,7 @@

let updates = {
let update = build_update_and_set_query(model, args, selected_fields, ctx);
let ids: Vec<&SelectionResult> = ids.iter().take(limit.unwrap_or(i64::MAX) as usize).collect();
let ids: Vec<&SelectionResult> = ids.iter().take(limit.unwrap_or(usize::MAX)).collect();

chunk_update_with_ids(update, model, &ids, filter_condition, ctx)?
};
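A side note on the take(limit.unwrap_or(usize::MAX)) idiom used here and in the MongoDB connector: with the limit already a usize, "no limit" is expressed directly as usize::MAX, with no i64::MAX-then-cast step. A minimal illustration (not part of the diff):

    fn main() {
        let ids = vec![1, 2, 3, 4];
        let limit: Option<usize> = None;
        // take(usize::MAX) is effectively "take everything".
        let taken: Vec<_> = ids.iter().take(limit.unwrap_or(usize::MAX)).collect();
        assert_eq!(taken.len(), 4);
    }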
@@ -376,7 +376,7 @@ async fn generate_updates(
record_filter: RecordFilter,
args: WriteArgs,
selected_fields: Option<&ModelProjection>,
limit: Option<i64>,
limit: Option<usize>,
ctx: &Context<'_>,
) -> crate::Result<Vec<Query<'static>>> {
if record_filter.has_selectors() {
@@ -399,7 +399,7 @@ pub(crate) async fn update_records(
model: &Model,
record_filter: RecordFilter,
args: WriteArgs,
limit: Option<i64>,
limit: Option<usize>,
ctx: &Context<'_>,
) -> crate::Result<usize> {
if args.args.is_empty() {
@@ -421,7 +421,7 @@ pub(crate) async fn update_records_returning(
record_filter: RecordFilter,
args: WriteArgs,
selected_fields: FieldSelection,
limit: Option<i64>,
limit: Option<usize>,
ctx: &Context<'_>,
) -> crate::Result<ManyRecords> {
let field_names: Vec<String> = selected_fields.db_names().collect();
@@ -458,7 +458,7 @@ pub(crate) async fn delete_records(
conn: &dyn Queryable,
model: &Model,
record_filter: RecordFilter,
limit: Option<i64>,
limit: Option<usize>,
ctx: &Context<'_>,
) -> crate::Result<usize> {
let filter_condition = FilterBuilder::without_top_level_joins().visit_filter(record_filter.clone().filter, ctx);
@@ -474,7 +474,8 @@
{
row_count += conn.execute(delete).await?;
if let Some(old_remaining_limit) = remaining_limit {
let new_remaining_limit = old_remaining_limit - row_count as i64;
// u64 to usize cast here cannot 'overflow' as the number of rows was limited to MAX usize in the first place.
let new_remaining_limit = old_remaining_limit - row_count as usize;
if new_remaining_limit <= 0 {
Check failure (GitHub Actions / clippy linting) on line 479 in query-engine/connectors/sql-query-connector/src/database/operations/write.rs: this comparison involving the minimum or maximum element for this type contains a case that is always true or always false
break;
}
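On the clippy failure flagged above (my reading, not part of the commit): new_remaining_limit is now a usize, so the strictly-negative half of the <= 0 comparison can never be true and the lint fires. A small runnable illustration, plus the presumed follow-up:

    fn main() {
        // For an unsigned integer such as usize, "x <= 0" can only ever mean "x == 0";
        // the "< 0" case is impossible, which is exactly what clippy points out.
        let new_remaining_limit: usize = 0;
        assert_eq!(new_remaining_limit <= 0, new_remaining_limit == 0);
        // The likely fix (an assumption, not code from this commit) is therefore:
        // if new_remaining_limit == 0 { break; }
    }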
@@ -220,7 +220,7 @@ impl WriteOperations for SqlConnectorTransaction<'_> {
model: &Model,
record_filter: RecordFilter,
args: WriteArgs,
limit: Option<i64>,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> connector::Result<usize> {
let ctx = Context::new(&self.connection_info, traceparent);
@@ -237,7 +237,7 @@ impl WriteOperations for SqlConnectorTransaction<'_> {
record_filter: RecordFilter,
args: WriteArgs,
selected_fields: FieldSelection,
limit: Option<i64>,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> connector::Result<ManyRecords> {
let ctx = Context::new(&self.connection_info, traceparent);
@@ -283,7 +283,7 @@ impl WriteOperations for SqlConnectorTransaction<'_> {
&mut self,
model: &Model,
record_filter: RecordFilter,
limit: Option<i64>,
limit: Option<usize>,
traceparent: Option<TraceParent>,
) -> connector::Result<usize> {
catch(&self.connection_info, async {
4 changes: 2 additions & 2 deletions query-engine/connectors/sql-query-connector/src/limit.rs
@@ -5,7 +5,7 @@ use query_structure::*;
pub(crate) fn wrap_with_limit_subquery_if_needed<'a>(
model: &Model,
filter_condition: ConditionTree<'a>,
limit: Option<i64>,
limit: Option<usize>,
ctx: &Context,
) -> ConditionTree<'a> {
if let Some(limit) = limit {
Expand All @@ -22,7 +22,7 @@ pub(crate) fn wrap_with_limit_subquery_if_needed<'a>(
Select::from_table(model.as_table(ctx))
.columns(columns)
.so_that(filter_condition)
.limit(limit as usize),
.limit(limit),
),
)
} else {
@@ -227,7 +227,7 @@ pub(crate) fn delete_returning(
pub(crate) fn delete_many_from_filter(
model: &Model,
filter_condition: ConditionTree<'static>,
limit: Option<i64>,
limit: Option<usize>,
ctx: &Context<'_>,
) -> Query<'static> {
let filter_condition = wrap_with_limit_subquery_if_needed(model, filter_condition, limit, ctx);
@@ -242,7 +242,7 @@ pub(crate) fn delete_many_from_ids_and_filter(
model: &Model,
ids: &[&SelectionResult],
filter_condition: ConditionTree<'static>,
limit: Option<i64>,
limit: Option<usize>,
ctx: &Context<'_>,
) -> Vec<Query<'static>> {
let columns: Vec<_> = ModelProjection::from(model.primary_identifier())
4 changes: 2 additions & 2 deletions query-engine/core/src/query_ast/write.rs
@@ -368,7 +368,7 @@ pub struct UpdateManyRecords {
/// Fields of updated records that client has requested to return.
/// `None` if the connector does not support returning the updated rows.
pub selected_fields: Option<UpdateManyRecordsFields>,
pub limit: Option<i64>,
pub limit: Option<usize>,
}

#[derive(Debug, Clone)]
@@ -398,7 +398,7 @@ pub struct DeleteRecordFields {
pub struct DeleteManyRecords {
pub model: Model,
pub record_filter: RecordFilter,
pub limit: Option<i64>,
pub limit: Option<usize>,
}

#[derive(Debug, Clone)]
16 changes: 7 additions & 9 deletions query-engine/core/src/query_graph_builder/write/delete.rs
@@ -1,12 +1,12 @@
use super::*;
use crate::query_document::ParsedInputValue;
use crate::query_graph_builder::write::limit::validate_limit;
use crate::{
query_ast::*,
query_graph::{Node, QueryGraph, QueryGraphDependency},
ArgumentListLookup, FilteredQuery, ParsedField,
};
use psl::datamodel_connector::ConnectorCapability;
use query_structure::{Filter, Model, PrismaValue};
use query_structure::{Filter, Model};
use schema::{constants::args, QuerySchema};
use std::convert::TryInto;

@@ -111,13 +111,11 @@ pub fn delete_many_records(
Some(where_arg) => extract_filter(where_arg.value.try_into()?, &model)?,
None => Filter::empty(),
};
let limit = field
.arguments
.lookup(args::LIMIT)
.and_then(|limit_arg| match limit_arg.value {
ParsedInputValue::Single(PrismaValue::Int(i)) => Some(i),
_ => None,
});

let limit = match validate_limit(field.arguments.lookup(args::LIMIT)) {
Ok(limit) => limit,
Err(err) => return Err(err),
};

let model_id = model.primary_identifier();
let record_filter = filter.clone().into();
31 changes: 31 additions & 0 deletions query-engine/core/src/query_graph_builder/write/limit.rs
@@ -0,0 +1,31 @@
use crate::query_document::{ParsedArgument, ParsedInputValue};
use crate::query_graph_builder::{QueryGraphBuilderError, QueryGraphBuilderResult};
use query_structure::PrismaValue;

pub(crate) fn validate_limit<'a>(limit_arg: Option<ParsedArgument<'a>>) -> QueryGraphBuilderResult<Option<usize>> {
Check failure (GitHub Actions / clippy linting) on line 5 in query-engine/core/src/query_graph_builder/write/limit.rs: the following explicit lifetimes could be elided: 'a
let limit = limit_arg.and_then(|limit_arg| match limit_arg.value {
ParsedInputValue::Single(PrismaValue::Int(i)) => Some(i),
_ => None,
});

match limit {
Some(i) => {
if i < 0 {
return Err(QueryGraphBuilderError::InputError(format!(
"Provided limit ({}) must be a positive integer.",
i
)));
}

match usize::try_from(i) {
Ok(i) => Ok(Some(i)),
Err(_) => Err(QueryGraphBuilderError::InputError(format!(
"Provided limit ({}) is beyond max int value for platform ({}).",
i,
usize::MAX
))),
}
}
None => Ok(None),
}
}
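Summarising the new helper's behaviour: a missing limit passes through as None, a non-negative value becomes Some(usize), negative values are rejected, and values that do not fit the platform's usize are rejected. Below is a standalone mirror of that logic for illustration only, with a hypothetical name and a plain Option<i64> in place of the internal ParsedArgument/PrismaValue types. Separately, the clippy note above means the explicit 'a lifetime on validate_limit could be elided (Option<ParsedArgument<'_>>).

    // Illustration only: hypothetical standalone mirror of validate_limit.
    fn validate_limit_value(limit: Option<i64>) -> Result<Option<usize>, String> {
        match limit {
            None => Ok(None),
            Some(i) if i < 0 => Err(format!("Provided limit ({i}) must be a positive integer.")),
            Some(i) => usize::try_from(i).map(Some).map_err(|_| {
                format!(
                    "Provided limit ({i}) is beyond max int value for platform ({}).",
                    usize::MAX
                )
            }),
        }
    }

    fn main() {
        assert_eq!(validate_limit_value(None), Ok(None));
        assert_eq!(validate_limit_value(Some(5)), Ok(Some(5)));
        assert!(validate_limit_value(Some(-1)).is_err());
    }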
(Diffs for the remaining changed files were not loaded on this page.)