Skip to content

Commit

Permalink
feat: add support for nested transaction rollbacks via savepoints in sql
Browse files Browse the repository at this point in the history
This is my first OSS contribution for a Rust project, so I'm sure I've
made some stupid mistakes, but I think it should mostly work :)

This change adds a mutable depth counter, that can track how many levels
deep a transaction is, and uses savepoints to implement correct rollback
behaviour. Previously, once a nested transaction was complete, it would
be saved with `COMMIT`, meaning that even if the outer transaction was
rolled back, the operations in the inner transaction would persist. With
this change, if the outer transaction gets rolled back, then all inner
transactions will also be rolled back.

Different flavours of SQL servers have different syntax for handling
savepoints, so I've had to add new methods to the `Queryable` trait for
getting the commit and rollback statements. These are both parameterized
by the current depth.

I've additionally had to modify the `begin_statement` method to accept a depth
parameter, as it will need to conditionally create a savepoint.

When opening a transaction via the transaction server, you can now pass
the prior transaction ID to re-use the existing transaction,
incrementing the depth.

Signed-off-by: Lucian Buzzo <[email protected]>
  • Loading branch information
LucianBuzzo committed Nov 28, 2023
1 parent 5c4707e commit f486ca2
Show file tree
Hide file tree
Showing 29 changed files with 615 additions and 124 deletions.
54 changes: 47 additions & 7 deletions quaint/src/connector/mssql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ use futures::lock::Mutex;
use std::{
convert::TryFrom,
future::Future,
sync::atomic::{AtomicBool, Ordering},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tiberius::*;
Expand All @@ -44,11 +47,13 @@ impl TransactionCapable for Mssql {
.or(self.url.query_params.transaction_isolation_level)
.or(Some(SQL_SERVER_DEFAULT_ISOLATION));

let opts = TransactionOptions::new(isolation, self.requires_isolation_first());
let opts = TransactionOptions::new(
isolation,
self.requires_isolation_first(),
self.transaction_depth.clone(),
);

Ok(Box::new(
DefaultTransaction::new(self, self.begin_statement(), opts).await?,
))
Ok(Box::new(DefaultTransaction::new(self, opts).await?))
}
}

Expand All @@ -59,6 +64,7 @@ pub struct Mssql {
url: MssqlUrl,
socket_timeout: Option<Duration>,
is_healthy: AtomicBool,
transaction_depth: Arc<Mutex<i32>>,
}

impl Mssql {
Expand Down Expand Up @@ -90,6 +96,7 @@ impl Mssql {
url,
socket_timeout,
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(Mutex::new(0)),
};

if let Some(isolation) = this.url.transaction_isolation_level() {
Expand Down Expand Up @@ -229,8 +236,41 @@ impl Queryable for Mssql {
Ok(())
}

fn begin_statement(&self) -> &'static str {
"BEGIN TRAN"
/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVE TRANSACTION savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"BEGIN TRAN".to_string()
};

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
// MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested
// transaction we just continue onwards
let ret = if depth > 1 {
" ".to_string()
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TRANSACTION savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}

fn requires_isolation_first(&self) -> bool {
Expand Down
39 changes: 38 additions & 1 deletion quaint/src/connector/mysql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ use mysql_async::{
};
use std::{
future::Future,
sync::atomic::{AtomicBool, Ordering},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio::sync::Mutex;
Expand Down Expand Up @@ -74,6 +77,7 @@ pub struct Mysql {
socket_timeout: Option<Duration>,
is_healthy: AtomicBool,
statement_cache: Mutex<LruCache<String, my::Statement>>,
transaction_depth: Arc<futures::lock::Mutex<i32>>,
}

impl Mysql {
Expand All @@ -87,6 +91,7 @@ impl Mysql {
statement_cache: Mutex::new(url.cache()),
url,
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}

Expand Down Expand Up @@ -294,4 +299,36 @@ impl Queryable for Mysql {
fn requires_isolation_first(&self) -> bool {
true
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}
}
39 changes: 38 additions & 1 deletion quaint/src/connector/postgres/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ use std::{
fmt::{Debug, Display},
fs,
future::Future,
sync::atomic::{AtomicBool, Ordering},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio_postgres::{config::ChannelBinding, Client, Config, Statement};
Expand All @@ -50,6 +53,7 @@ pub struct PostgreSql {
socket_timeout: Option<Duration>,
statement_cache: Mutex<LruCache<String, Statement>>,
is_healthy: AtomicBool,
transaction_depth: Arc<Mutex<i32>>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -243,6 +247,7 @@ impl PostgreSql {
pg_bouncer: url.query_params.pg_bouncer,
statement_cache: Mutex::new(url.cache()),
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(Mutex::new(0)),
})
}

Expand Down Expand Up @@ -523,6 +528,38 @@ impl Queryable for PostgreSql {
fn requires_isolation_first(&self) -> bool {
false
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}
}

/// Sorted list of CockroachDB's reserved keywords.
Expand Down
39 changes: 35 additions & 4 deletions quaint/src/connector/queryable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,35 @@ pub trait Queryable: Send + Sync {
}

/// Statement to begin a transaction
fn begin_statement(&self) -> &'static str {
"BEGIN"
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}

/// Sets the transaction isolation level to given value.
Expand Down Expand Up @@ -117,10 +144,14 @@ macro_rules! impl_default_TransactionCapable {
&'a self,
isolation: Option<IsolationLevel>,
) -> crate::Result<Box<dyn crate::connector::Transaction + 'a>> {
let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first());
let opts = crate::connector::TransactionOptions::new(
isolation,
self.requires_isolation_first(),
self.transaction_depth.clone(),
);

Ok(Box::new(
crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?,
crate::connector::DefaultTransaction::new(self, opts).await?,
))
}
}
Expand Down
41 changes: 39 additions & 2 deletions quaint/src/connector/sqlite/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
visitor::{self, Visitor},
};
use async_trait::async_trait;
use std::convert::TryFrom;
use std::{convert::TryFrom, sync::Arc};
use tokio::sync::Mutex;

/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature.
Expand All @@ -26,6 +26,7 @@ pub use rusqlite;
/// A connector interface for the SQLite database
pub struct Sqlite {
pub(crate) client: Mutex<rusqlite::Connection>,
transaction_depth: Arc<futures::lock::Mutex<i32>>,
}

impl TryFrom<&str> for Sqlite {
Expand All @@ -43,7 +44,10 @@ impl TryFrom<&str> for Sqlite {

let client = Mutex::new(conn);

Ok(Sqlite { client })
Ok(Sqlite {
client,
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}
}

Expand All @@ -58,6 +62,7 @@ impl Sqlite {

Ok(Sqlite {
client: Mutex::new(client),
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}

Expand Down Expand Up @@ -154,6 +159,38 @@ impl Queryable for Sqlite {
fn requires_isolation_first(&self) -> bool {
false
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit f486ca2

Please sign in to comment.