Skip to content

Commit

Permalink
feat: create related records in bulk (#4698)
Browse files Browse the repository at this point in the history
  • Loading branch information
laplab authored Feb 26, 2024
1 parent af6ceee commit 45a92dc
Show file tree
Hide file tree
Showing 14 changed files with 424 additions and 153 deletions.
3 changes: 3 additions & 0 deletions quaint/src/ast/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub struct MultiRowInsert<'a> {
pub(crate) table: Option<Table<'a>>,
pub(crate) columns: Vec<Column<'a>>,
pub(crate) values: Vec<Row<'a>>,
pub(crate) returning: Option<Vec<Column<'a>>>,
}

/// `INSERT` conflict resolution strategies.
Expand Down Expand Up @@ -186,6 +187,7 @@ impl<'a> Insert<'a> {
table: Some(table.into()),
columns: columns.into_iter().map(|c| c.into()).collect(),
values: Vec::new(),
returning: None,
}
}

Expand All @@ -198,6 +200,7 @@ impl<'a> Insert<'a> {
table: None,
columns: columns.into_iter().map(|c| c.into()).collect(),
values: Vec::new(),
returning: None,
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ impl WriteOperations for MongoDbConnection {
.await
}

/// Bulk create that would return the inserted rows' selected fields.
/// Not supported by the MongoDB connector; the trait docs state this method
/// "should not be used if the connector does not support returning created
/// rows". NOTE(review): assumes the query planner gates on a connector
/// capability before routing here — confirm upstream.
async fn create_records_returning(
&mut self,
_model: &Model,
_args: Vec<WriteArgs>,
_skip_duplicates: bool,
_selected_fields: FieldSelection,
_trace_id: Option<String>,
) -> connector_interface::Result<ManyRecords> {
// Deliberate panic: reaching this is a routing bug, not a runtime condition.
unimplemented!()
}

async fn update_records(
&mut self,
model: &Model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,17 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> {
.await
}

/// Bulk create returning inserted rows' selected fields — unsupported on
/// MongoDB (mirrors the connection-level stub). NOTE(review): assumes
/// capability gating upstream prevents this from ever being called — confirm.
async fn create_records_returning(
&mut self,
_model: &Model,
_args: Vec<connector_interface::WriteArgs>,
_skip_duplicates: bool,
_selected_fields: FieldSelection,
_trace_id: Option<String>,
) -> connector_interface::Result<ManyRecords> {
// Deliberate panic: reaching this is a routing bug, not a runtime condition.
unimplemented!()
}

async fn update_records(
&mut self,
model: &Model,
Expand Down
20 changes: 17 additions & 3 deletions query-engine/connectors/query-connector/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,19 @@ pub trait WriteOperations {
trace_id: Option<String>,
) -> crate::Result<usize>;

/// Inserts many records at once into the database and returns their
/// selected fields.
///
/// * `args` — one `WriteArgs` entry per record to insert.
/// * `skip_duplicates` — when `true`, rows that conflict with existing ones
///   are skipped rather than failing the whole insert. NOTE(review): exact
///   conflict semantics are connector-defined — confirm per connector.
/// * `selected_fields` — the fields whose values are returned for every
///   inserted row.
///
/// This method should not be used if the connector does not support
/// returning created rows (connectors without support panic — see the
/// MongoDB impls).
async fn create_records_returning(
&mut self,
model: &Model,
args: Vec<WriteArgs>,
skip_duplicates: bool,
selected_fields: FieldSelection,
trace_id: Option<String>,
) -> crate::Result<ManyRecords>;

/// Update records in the `Model` with the given `WriteArgs` filtered by the
/// `Filter`.
async fn update_records(
Expand Down Expand Up @@ -299,9 +312,10 @@ pub trait WriteOperations {
trace_id: Option<String>,
) -> crate::Result<usize>;

/// Delete single record in the `Model` with the given `Filter`.
/// Return selected fields of the deleted record, if the connector
/// supports it. If the connector does not support it, error is returned.
/// Delete single record in the `Model` with the given `Filter` and returns
/// selected fields of the deleted record.
/// This method should not be used if the connector does not support returning
/// deleted rows.
async fn delete_record(
&mut self,
model: &Model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,23 @@ where
let ctx = Context::new(&self.connection_info, trace_id.as_deref());
catch(
&self.connection_info,
write::create_records(&self.inner, model, args, skip_duplicates, &ctx),
write::create_records_count(&self.inner, model, args, skip_duplicates, &ctx),
)
.await
}

/// Delegates bulk create-and-return to the SQL write implementation,
/// translating low-level errors through `catch`.
async fn create_records_returning(
    &mut self,
    model: &Model,
    args: Vec<WriteArgs>,
    skip_duplicates: bool,
    selected_fields: FieldSelection,
    trace_id: Option<String>,
) -> connector::Result<ManyRecords> {
    let ctx = Context::new(&self.connection_info, trace_id.as_deref());
    let create = write::create_records_returning(&self.inner, model, args, skip_duplicates, selected_fields, &ctx);

    catch(&self.connection_info, create).await
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{
};
use connector_interface::*;
use itertools::Itertools;
use quaint::ast::Insert;
use quaint::{
error::ErrorKind,
prelude::{native_uuid, uuid_to_bin, uuid_to_bin_swapped, Aliasable, Select, SqlFamily},
Expand Down Expand Up @@ -194,45 +195,96 @@ pub(crate) async fn create_record(
}
}

pub(crate) async fn create_records(
conn: &dyn Queryable,
model: &Model,
args: Vec<WriteArgs>,
skip_duplicates: bool,
ctx: &Context<'_>,
) -> crate::Result<usize> {
if args.is_empty() {
return Ok(0);
}

// Compute the set of fields affected by the createMany.
/// Returns a set of fields that are used in the arguments for the create operation.
fn collect_affected_fields(args: &[WriteArgs], model: &Model) -> HashSet<ScalarFieldRef> {
let mut fields = HashSet::new();
args.iter().for_each(|arg| fields.extend(arg.keys()));

#[allow(clippy::mutable_key_type)]
let affected_fields: HashSet<ScalarFieldRef> = fields
fields
.into_iter()
.map(|dsfn| model.fields().scalar().find(|sf| sf.db_name() == dsfn.deref()).unwrap())
.collect();
.collect()
}

/// Generates a list of insert statements to execute. If `selected_fields` is set, insert statements
/// will return the specified columns of inserted rows.
fn generate_insert_statements(
    model: &Model,
    args: Vec<WriteArgs>,
    skip_duplicates: bool,
    selected_fields: Option<&ModelProjection>,
    ctx: &Context<'_>,
) -> Vec<Insert<'static>> {
    #[allow(clippy::mutable_key_type)]
    let affected_fields = collect_affected_fields(&args, model);

    if affected_fields.is_empty() {
        // If no fields are to be inserted (everything is DEFAULT) we need to fall back to inserting default rows `args.len()` times.
        args.into_iter()
            .map(|_| write::create_records_empty(model, skip_duplicates, selected_fields, ctx))
            .collect()
    } else {
        // Split the records so no statement exceeds the connector's bind-value
        // or row-count limits.
        let partitioned_batches = partition_into_batches(args, ctx);
        trace!("Total of {} batches to be executed.", partitioned_batches.len());
        trace!(
            "Batch sizes: {:?}",
            partitioned_batches.iter().map(|b| b.len()).collect_vec()
        );

        partitioned_batches
            .into_iter()
            .map(|batch| {
                write::create_records_nonempty(model, batch, skip_duplicates, &affected_fields, selected_fields, ctx)
            })
            .collect()
    }
}

/// Standard create many records, requires `affected_fields` to be non-empty.
#[allow(clippy::mutable_key_type)]
async fn create_many_nonempty(
/// Inserts records specified as a list of `WriteArgs`. Returns number of inserted records.
pub(crate) async fn create_records_count(
conn: &dyn Queryable,
model: &Model,
args: Vec<WriteArgs>,
skip_duplicates: bool,
affected_fields: HashSet<ScalarFieldRef>,
ctx: &Context<'_>,
) -> crate::Result<usize> {
let inserts = generate_insert_statements(model, args, skip_duplicates, None, ctx);
let mut count = 0;
for insert in inserts {
count += conn.execute(insert.into()).await?;
}

Ok(count as usize)
}

/// Inserts records specified as a list of `WriteArgs`. Returns values of fields specified in
/// `selected_fields` for all inserted rows.
pub(crate) async fn create_records_returning(
    conn: &dyn Queryable,
    model: &Model,
    args: Vec<WriteArgs>,
    skip_duplicates: bool,
    selected_fields: FieldSelection,
    ctx: &Context<'_>,
) -> crate::Result<ManyRecords> {
    // Column names and type metadata drive the result-row decoding below.
    let columns: Vec<String> = selected_fields.db_names().collect();
    let types = selected_fields.type_identifiers_with_arities();
    let meta = column_metadata::create(&columns, &types);

    let mut records = ManyRecords::new(columns.clone());
    let inserts = generate_insert_statements(model, args, skip_duplicates, Some(&selected_fields.into()), ctx);

    for insert in inserts {
        let result_set = conn.query(insert.into()).await?;

        for row in result_set {
            records.push(Record::from(row.to_sql_row(&meta)?));
        }
    }

    Ok(records)
}

/// Partitions data into batches, respecting `max_bind_values` and `max_insert_rows` settings from
/// the `Context`.
fn partition_into_batches(args: Vec<WriteArgs>, ctx: &Context<'_>) -> Vec<Vec<WriteArgs>> {
let batches = if let Some(max_params) = ctx.max_bind_values {
// We need to split inserts if they are above a parameter threshold, as well as split based on number of rows.
// -> Horizontal partitioning by row number, vertical by number of args.
Expand Down Expand Up @@ -274,7 +326,7 @@ async fn create_many_nonempty(
vec![args]
};

let partitioned_batches = if let Some(max_rows) = ctx.max_insert_rows {
if let Some(max_rows) = ctx.max_insert_rows {
let capacity = batches.len();
batches
.into_iter()
Expand All @@ -295,39 +347,7 @@ async fn create_many_nonempty(
})
} else {
batches
};

trace!("Total of {} batches to be executed.", partitioned_batches.len());
trace!(
"Batch sizes: {:?}",
partitioned_batches.iter().map(|b| b.len()).collect_vec()
);

let mut count = 0;
for batch in partitioned_batches {
let stmt = write::create_records_nonempty(model, batch, skip_duplicates, &affected_fields, ctx);
count += conn.execute(stmt.into()).await?;
}

Ok(count as usize)
}

/// Creates many empty (all default values) rows.
async fn create_many_empty(
conn: &dyn Queryable,
model: &Model,
num_records: usize,
skip_duplicates: bool,
ctx: &Context<'_>,
) -> crate::Result<usize> {
let stmt = write::create_records_empty(model, skip_duplicates, ctx);
let mut count = 0;

for _ in 0..num_records {
count += conn.execute(stmt.clone().into()).await?;
}

Ok(count as usize)
}

/// Update one record in a database defined in `conn` and the records
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,30 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> {
let ctx = Context::new(&self.connection_info, trace_id.as_deref());
catch(
&self.connection_info,
write::create_records(self.inner.as_queryable(), model, args, skip_duplicates, &ctx),
write::create_records_count(self.inner.as_queryable(), model, args, skip_duplicates, &ctx),
)
.await
}

/// Transaction-scoped variant of bulk create-and-return: forwards to the SQL
/// write implementation over the transaction's queryable, mapping errors via
/// `catch`.
async fn create_records_returning(
    &mut self,
    model: &Model,
    args: Vec<WriteArgs>,
    skip_duplicates: bool,
    selected_fields: FieldSelection,
    trace_id: Option<String>,
) -> connector::Result<ManyRecords> {
    let ctx = Context::new(&self.connection_info, trace_id.as_deref());
    let create = write::create_records_returning(
        self.inner.as_queryable(),
        model,
        args,
        skip_duplicates,
        selected_fields,
        &ctx,
    );

    catch(&self.connection_info, create).await
}
Expand Down
Loading

0 comments on commit 45a92dc

Please sign in to comment.