Skip to content

Commit

Permalink
feat: add support for nested transaction rollbacks via savepoints in sql
Browse files Browse the repository at this point in the history
This is my first OSS contribution for a Rust project, so I'm sure I've
made some stupid mistakes, but I think it should mostly work :)

This change adds a mutable depth counter, that can track how many levels
deep a transaction is, and uses savepoints to implement correct rollback
behaviour. Previously, once a nested transaction was complete, it would
be saved with `COMMIT`, meaning that even if the outer transaction was
rolled back, the operations in the inner transaction would persist. With
this change, if the outer transaction gets rolled back, then all inner
transactions will also be rolled back.

Different flavours of SQL servers have different syntax for handling
savepoints, so I've had to add new methods to the `Queryable` trait for
getting the commit and rollback statements. These are both parameterized
by the current depth.

I've additionally had to modify the `begin_statement` method to accept a depth
parameter, as it will need to conditionally create a savepoint.

When opening a transaction via the transaction server, you can now pass
the prior transaction ID to re-use the existing transaction,
incrementing the depth.

Signed-off-by: Lucian Buzzo <[email protected]>
  • Loading branch information
LucianBuzzo committed Nov 9, 2024
1 parent 2c64fd3 commit cd31b26
Show file tree
Hide file tree
Showing 31 changed files with 592 additions and 143 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,8 @@ ensure-prisma-present:
echo "⚠️ ../prisma diverges from prisma/prisma main branch. Test results might diverge from those in CI ⚠️ "; \
fi \
else \
echo "git clone --depth=1 https://github.com/prisma/prisma.git --branch=$(DRIVER_ADAPTERS_BRANCH) ../prisma"; \
git clone --depth=1 https://github.com/prisma/prisma.git --branch=$(DRIVER_ADAPTERS_BRANCH) "../prisma" && echo "Prisma repository has been cloned to ../prisma"; \
echo "git clone --depth=1 https://github.com/LucianBuzzo/prisma.git --branch=lucianbuzzo/nested-rollbacks ../prisma"; \
git clone --depth=1 https://github.com/LucianBuzzo/prisma.git --branch=lucianbuzzo/nested-rollbacks "../prisma" && echo "Prisma repository has been cloned to ../prisma"; \
fi;

# Quick schema validation of whatever you have in the dev_datamodel.prisma file.
Expand Down
34 changes: 29 additions & 5 deletions quaint/src/connector/mssql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::{
};
use async_trait::async_trait;
use futures::lock::Mutex;
use std::borrow::Cow;
use std::{
convert::TryFrom,
future::Future,
Expand Down Expand Up @@ -48,9 +49,7 @@ impl TransactionCapable for Mssql {

let opts = TransactionOptions::new(isolation, self.requires_isolation_first());

Ok(Box::new(
DefaultTransaction::new(self, self.begin_statement(), opts).await?,
))
Ok(Box::new(DefaultTransaction::new(self, opts).await?))
}
}

Expand Down Expand Up @@ -244,8 +243,33 @@ impl Queryable for Mssql {
Ok(())
}

fn begin_statement(&self) -> &'static str {
"BEGIN TRAN"
/// Statement to begin a transaction
fn begin_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("SAVE TRANSACTION savepoint{depth}"))
} else {
Cow::Borrowed("BEGIN TRAN")
}
}

/// Statement to commit a transaction
fn commit_statement(&self, depth: u32) -> Cow<'static, str> {
// MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested
// transaction we just continue onwards
if depth > 1 {
Cow::Owned("".to_string())
} else {
Cow::Borrowed("COMMIT")
}
}

/// Statement to rollback a transaction
fn rollback_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("ROLLBACK TRANSACTION savepoint{depth}"))
} else {
Cow::Borrowed("ROLLBACK")
}
}

fn requires_isolation_first(&self) -> bool {
Expand Down
28 changes: 28 additions & 0 deletions quaint/src/connector/mysql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use mysql_async::{
self as my,
prelude::{Query as _, Queryable as _},
};
use std::borrow::Cow;
use std::{
future::Future,
sync::atomic::{AtomicBool, Ordering},
Expand Down Expand Up @@ -347,4 +348,31 @@ impl Queryable for Mysql {
fn requires_isolation_first(&self) -> bool {
true
}

/// Statement to begin a transaction
fn begin_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("BEGIN")
}
}

/// Statement to commit a transaction
fn commit_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("COMMIT")
}
}

/// Statement to rollback a transaction
fn rollback_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("ROLLBACK TO savepoint{depth}"))
} else {
Cow::Borrowed("ROLLBACK")
}
}
}
28 changes: 28 additions & 0 deletions quaint/src/connector/postgres/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use native_tls::{Certificate, Identity, TlsConnector};
use postgres_native_tls::MakeTlsConnector;
use postgres_types::{Kind as PostgresKind, Type as PostgresType};
use prisma_metrics::WithMetricsInstrumentation;
use std::borrow::Cow;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::{
fmt::{Debug, Display},
Expand Down Expand Up @@ -806,6 +807,33 @@ impl Queryable for PostgreSql {
fn requires_isolation_first(&self) -> bool {
false
}

/// Statement to begin a transaction
fn begin_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("BEGIN")
}
}

/// Statement to commit a transaction
fn commit_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("COMMIT")
}
}

/// Statement to rollback a transaction
fn rollback_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("ROLLBACK")
}
}
}

/// Sorted list of CockroachDB's reserved keywords.
Expand Down
30 changes: 27 additions & 3 deletions quaint/src/connector/queryable.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::borrow::Cow;

use super::{DescribedQuery, IsolationLevel, ResultSet, Transaction};
use crate::ast::*;
use async_trait::async_trait;
Expand Down Expand Up @@ -90,8 +92,30 @@ pub trait Queryable: Send + Sync {
}

/// Statement to begin a transaction
fn begin_statement(&self) -> &'static str {
"BEGIN"
fn begin_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("BEGIN")
}
}

/// Statement to commit a transaction
fn commit_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("COMMIT")
}
}

/// Statement to rollback a transaction
fn rollback_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("ROLLBACK")
}
}

/// Sets the transaction isolation level to given value.
Expand Down Expand Up @@ -123,7 +147,7 @@ macro_rules! impl_default_TransactionCapable {
let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first());

Ok(Box::new(
crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?,
crate::connector::DefaultTransaction::new(self, opts).await?,
))
}
}
Expand Down
28 changes: 26 additions & 2 deletions quaint/src/connector/sqlite/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::{
visitor::{self, Visitor},
};
use async_trait::async_trait;
use std::borrow::Cow;
use std::convert::TryFrom;
use tokio::sync::Mutex;

Expand Down Expand Up @@ -183,12 +184,35 @@ impl Queryable for Sqlite {
false
}

fn begin_statement(&self) -> &'static str {
/// Statement to begin a transaction
fn begin_statement(&self, depth: u32) -> Cow<'static, str> {
// From https://sqlite.org/isolation.html:
// `BEGIN IMMEDIATE` avoids possible `SQLITE_BUSY_SNAPSHOT` that arise when another connection jumps ahead in line.
// The BEGIN IMMEDIATE command goes ahead and starts a write transaction, and thus blocks all other writers.
// If the BEGIN IMMEDIATE operation succeeds, then no subsequent operations in that transaction will ever fail with an SQLITE_BUSY error.
"BEGIN IMMEDIATE"
if depth > 1 {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("BEGIN IMMEDIATE")
}
}

/// Statement to commit a transaction
fn commit_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
} else {
Cow::Borrowed("COMMIT")
}
}

/// Statement to rollback a transaction
fn rollback_statement(&self, depth: u32) -> Cow<'static, str> {
if depth > 1 {
Cow::Owned(format!("ROLLBACK TO savepoint{depth}"))
} else {
Cow::Borrowed("ROLLBACK")
}
}
}

Expand Down
88 changes: 73 additions & 15 deletions quaint/src/connector/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::{fmt, str::FromStr};

use async_trait::async_trait;
use prisma_metrics::guards::GaugeGuard;

Expand All @@ -8,14 +6,22 @@ use crate::{
ast::*,
error::{Error, ErrorKind},
};
use std::{
fmt,
str::FromStr,
sync::{Arc, Mutex},
};

#[async_trait]
pub trait Transaction: Queryable {
/// Start a new transaction or nested transaction via savepoint.
async fn begin(&mut self) -> crate::Result<()>;

/// Commit the changes to the database and consume the transaction.
async fn commit(&self) -> crate::Result<()>;
async fn commit(&mut self) -> crate::Result<u32>;

/// Rolls back the changes to the database.
async fn rollback(&self) -> crate::Result<()>;
async fn rollback(&mut self) -> crate::Result<u32>;

/// workaround for lack of upcasting between traits https://github.com/rust-lang/rust/issues/65991
fn as_queryable(&self) -> &dyn Queryable;
Expand All @@ -36,18 +42,19 @@ pub(crate) struct TransactionOptions {
/// transaction object will panic.
pub struct DefaultTransaction<'a> {
pub inner: &'a dyn Queryable,
pub depth: Arc<Mutex<u32>>,
gauge: GaugeGuard,
}

impl<'a> DefaultTransaction<'a> {
pub(crate) async fn new(
inner: &'a dyn Queryable,
begin_stmt: &str,
tx_opts: TransactionOptions,
) -> crate::Result<DefaultTransaction<'a>> {
let this = Self {
let mut this = Self {
inner,
gauge: GaugeGuard::increment("prisma_client_queries_active"),
depth: Arc::new(Mutex::new(0)),
};

if tx_opts.isolation_first {
Expand All @@ -56,7 +63,7 @@ impl<'a> DefaultTransaction<'a> {
}
}

inner.raw_cmd(begin_stmt).await?;
this.begin().await?;

if !tx_opts.isolation_first {
if let Some(isolation) = tx_opts.isolation_level {
Expand All @@ -72,20 +79,71 @@ impl<'a> DefaultTransaction<'a> {

#[async_trait]
impl<'a> Transaction for DefaultTransaction<'a> {
/// Commit the changes to the database and consume the transaction.
async fn commit(&self) -> crate::Result<()> {
self.gauge.decrement();
self.inner.raw_cmd("COMMIT").await?;
async fn begin(&mut self) -> crate::Result<()> {
let current_depth = {
let mut depth = self.depth.lock().unwrap();
*depth += 1;
*depth
};

let begin_statement = self.inner.begin_statement(current_depth);

self.inner.raw_cmd(&begin_statement).await?;

Ok(())
}

/// Commit the changes to the database and consume the transaction.
async fn commit(&mut self) -> crate::Result<u32> {
// Lock the mutex and get the depth value
let depth_val = {
let depth = self.depth.lock().unwrap();
*depth
};

// Perform the asynchronous operation without holding the lock
let commit_statement = self.inner.commit_statement(depth_val);
self.inner.raw_cmd(&commit_statement).await?;

// Lock the mutex again to modify the depth
let new_depth = {
let mut depth = self.depth.lock().unwrap();
*depth -= 1;
*depth
};

if new_depth == 0 {
self.gauge.decrement();
}

Ok(new_depth)
}

/// Rolls back the changes to the database.
async fn rollback(&self) -> crate::Result<()> {
self.gauge.decrement();
self.inner.raw_cmd("ROLLBACK").await?;
async fn rollback(&mut self) -> crate::Result<u32> {
// Lock the mutex and get the depth value
let depth_val = {
let depth = self.depth.lock().unwrap();
*depth
};

Ok(())
// Perform the asynchronous operation without holding the lock
let rollback_statement = self.inner.rollback_statement(depth_val);

self.inner.raw_cmd(&rollback_statement).await?;

// Lock the mutex again to modify the depth
let new_depth = {
let mut depth = self.depth.lock().unwrap();
*depth -= 1;
*depth
};

if new_depth == 0 {
self.gauge.decrement();
}

Ok(new_depth)
}

fn as_queryable(&self) -> &dyn Queryable {
Expand Down
Loading

0 comments on commit cd31b26

Please sign in to comment.