Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for nested transaction rollbacks via savepoints in sql #4637

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions quaint/src/connector/mssql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ use futures::lock::Mutex;
use std::{
convert::TryFrom,
future::Future,
sync::atomic::{AtomicBool, Ordering},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tiberius::*;
Expand All @@ -44,11 +47,13 @@ impl TransactionCapable for Mssql {
.or(self.url.query_params.transaction_isolation_level)
.or(Some(SQL_SERVER_DEFAULT_ISOLATION));

let opts = TransactionOptions::new(isolation, self.requires_isolation_first());
let opts = TransactionOptions::new(
isolation,
self.requires_isolation_first(),
self.transaction_depth.clone(),
);

Ok(Box::new(
DefaultTransaction::new(self, self.begin_statement(), opts).await?,
))
Ok(Box::new(DefaultTransaction::new(self, opts).await?))
}
}

Expand All @@ -59,6 +64,7 @@ pub struct Mssql {
url: MssqlUrl,
socket_timeout: Option<Duration>,
is_healthy: AtomicBool,
transaction_depth: Arc<Mutex<i32>>,
}

impl Mssql {
Expand Down Expand Up @@ -90,6 +96,7 @@ impl Mssql {
url,
socket_timeout,
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(Mutex::new(0)),
};

if let Some(isolation) = this.url.transaction_isolation_level() {
Expand Down Expand Up @@ -229,8 +236,41 @@ impl Queryable for Mssql {
Ok(())
}

fn begin_statement(&self) -> &'static str {
"BEGIN TRAN"
/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVE TRANSACTION savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"BEGIN TRAN".to_string()
};

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
// MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested
// transaction we just continue onwards
let ret = if depth > 1 {
" ".to_string()
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TRANSACTION savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}

fn requires_isolation_first(&self) -> bool {
Expand Down
39 changes: 38 additions & 1 deletion quaint/src/connector/mysql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ use mysql_async::{
};
use std::{
future::Future,
sync::atomic::{AtomicBool, Ordering},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio::sync::Mutex;
Expand Down Expand Up @@ -74,6 +77,7 @@ pub struct Mysql {
socket_timeout: Option<Duration>,
is_healthy: AtomicBool,
statement_cache: Mutex<LruCache<String, my::Statement>>,
transaction_depth: Arc<futures::lock::Mutex<i32>>,
}

impl Mysql {
Expand All @@ -87,6 +91,7 @@ impl Mysql {
statement_cache: Mutex::new(url.cache()),
url,
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}

Expand Down Expand Up @@ -294,4 +299,36 @@ impl Queryable for Mysql {
fn requires_isolation_first(&self) -> bool {
true
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}
}
39 changes: 38 additions & 1 deletion quaint/src/connector/postgres/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ use std::{
fmt::{Debug, Display},
fs,
future::Future,
sync::atomic::{AtomicBool, Ordering},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio_postgres::{config::ChannelBinding, Client, Config, Statement};
Expand All @@ -50,6 +53,7 @@ pub struct PostgreSql {
socket_timeout: Option<Duration>,
statement_cache: Mutex<LruCache<String, Statement>>,
is_healthy: AtomicBool,
transaction_depth: Arc<Mutex<i32>>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -243,6 +247,7 @@ impl PostgreSql {
pg_bouncer: url.query_params.pg_bouncer,
statement_cache: Mutex::new(url.cache()),
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(Mutex::new(0)),
})
}

Expand Down Expand Up @@ -523,6 +528,38 @@ impl Queryable for PostgreSql {
fn requires_isolation_first(&self) -> bool {
false
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}
}

/// Sorted list of CockroachDB's reserved keywords.
Expand Down
39 changes: 35 additions & 4 deletions quaint/src/connector/queryable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,35 @@ pub trait Queryable: Send + Sync {
}

/// Statement to begin a transaction
fn begin_statement(&self) -> &'static str {
"BEGIN"
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}

/// Sets the transaction isolation level to given value.
Expand Down Expand Up @@ -117,10 +144,14 @@ macro_rules! impl_default_TransactionCapable {
&'a self,
isolation: Option<IsolationLevel>,
) -> crate::Result<Box<dyn crate::connector::Transaction + 'a>> {
let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first());
let opts = crate::connector::TransactionOptions::new(
isolation,
self.requires_isolation_first(),
self.transaction_depth.clone(),
);

Ok(Box::new(
crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?,
crate::connector::DefaultTransaction::new(self, opts).await?,
))
}
}
Expand Down
41 changes: 39 additions & 2 deletions quaint/src/connector/sqlite/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
visitor::{self, Visitor},
};
use async_trait::async_trait;
use std::convert::TryFrom;
use std::{convert::TryFrom, sync::Arc};
use tokio::sync::Mutex;

/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature.
Expand All @@ -26,6 +26,7 @@ pub use rusqlite;
/// A connector interface for the SQLite database
pub struct Sqlite {
pub(crate) client: Mutex<rusqlite::Connection>,
transaction_depth: Arc<futures::lock::Mutex<i32>>,
}

impl TryFrom<&str> for Sqlite {
Expand All @@ -43,7 +44,10 @@ impl TryFrom<&str> for Sqlite {

let client = Mutex::new(conn);

Ok(Sqlite { client })
Ok(Sqlite {
client,
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}
}

Expand All @@ -58,6 +62,7 @@ impl Sqlite {

Ok(Sqlite {
client: Mutex::new(client),
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}

Expand Down Expand Up @@ -154,6 +159,38 @@ impl Queryable for Sqlite {
fn requires_isolation_first(&self) -> bool {
false
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}
}

#[cfg(test)]
Expand Down
Loading
Loading