From 29c424a601638eb3f7c6e163ab70881c72af54b9 Mon Sep 17 00:00:00 2001 From: Lucian Buzzo Date: Sun, 15 Oct 2023 13:58:28 +0100 Subject: [PATCH] feat: add support for nested transaction rollbacks via savepoints in sql This is my first OSS contribution for a Rust project, so I'm sure I've made some stupid mistakes, but I think it should mostly work :) This change adds a mutable depth counter, that can track how many levels deep a transaction is, and uses savepoints to implement correct rollback behaviour. Previously, once a nested transaction was complete, it would be saved with `COMMIT`, meaning that even if the outer transaction was rolled back, the operations in the inner transaction would persist. With this change, if the outer transaction gets rolled back, then all inner transactions will also be rolled back. Different flavours of SQL servers have different syntax for handling savepoints, so I've had to add new methods to the `Queryable` trait for getting the commit and rollback statements. These are both parameterized by the current depth. I've additionally had to modify the `begin_statement` method to accept a depth parameter, as it will need to conditionally create a savepoint. When opening a transaction via the transaction server, you can now pass the prior transaction ID to re-use the existing transaction, incrementing the depth. Signed-off-by: Lucian Buzzo --- quaint/src/connector/mssql/native/mod.rs | 54 +++++++- quaint/src/connector/mysql/native/mod.rs | 39 +++++- quaint/src/connector/postgres/native/mod.rs | 39 +++++- quaint/src/connector/queryable.rs | 39 +++++- quaint/src/connector/sqlite/native/mod.rs | 41 +++++- quaint/src/connector/transaction.rs | 75 +++++++++-- quaint/src/pooled.rs | 5 +- quaint/src/pooled/manager.rs | 15 ++- quaint/src/single.rs | 21 ++- quaint/src/tests/query.rs | 16 ++- quaint/src/tests/query/error.rs | 2 +- .../tests/new/interactive_tx.rs | 127 ++++++++++++++---- .../query-engine-tests/tests/new/metrics.rs | 4 +- .../tests/new/regressions/prisma_13405.rs | 2 +- .../tests/new/regressions/prisma_15607.rs | 2 +- .../new/regressions/prisma_engines_4286.rs | 6 +- .../query-tests-setup/src/runner/mod.rs | 3 +- .../src/interface/transaction.rs | 12 +- .../query-connector/src/interface.rs | 5 +- .../src/database/transaction.rs | 18 +-- query-engine/core/src/executor/mod.rs | 10 +- .../interactive_transactions/actor_manager.rs | 34 +++-- .../src/interactive_transactions/actors.rs | 67 +++++++-- .../src/interactive_transactions/messages.rs | 8 +- .../core/src/interactive_transactions/mod.rs | 4 +- query-engine/driver-adapters/src/proxy.rs | 10 ++ query-engine/driver-adapters/src/queryable.rs | 16 ++- .../driver-adapters/src/transaction.rs | 56 +++++++- query-engine/query-engine/src/server/mod.rs | 6 +- 29 files changed, 610 insertions(+), 126 deletions(-) diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index 124e14ac94d0..13ff8dffdd00 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -17,7 +17,10 @@ use futures::lock::Mutex; use std::{ convert::TryFrom, future::Future, - sync::atomic::{AtomicBool, Ordering}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, time::Duration, }; use tiberius::*; @@ -44,11 +47,13 @@ impl TransactionCapable for Mssql { .or(self.url.query_params.transaction_isolation_level) .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); - let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); + let opts = TransactionOptions::new( + isolation, + self.requires_isolation_first(), + self.transaction_depth.clone(), + ); - Ok(Box::new( - DefaultTransaction::new(self, self.begin_statement(), opts).await?, - )) + Ok(Box::new(DefaultTransaction::new(self, opts).await?)) } } @@ -59,6 +64,7 @@ pub struct Mssql { url: MssqlUrl, socket_timeout: Option, is_healthy: AtomicBool, + transaction_depth: Arc>, } impl Mssql { @@ -90,6 +96,7 @@ impl Mssql { url, socket_timeout, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(Mutex::new(0)), }; if let Some(isolation) = this.url.transaction_isolation_level() { @@ -229,8 +236,41 @@ impl Queryable for Mssql { Ok(()) } - fn begin_statement(&self) -> &'static str { - "BEGIN TRAN" + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVE TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN TRAN".to_string() + }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + // MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested + // transaction we just continue onwards + let ret = if depth > 1 { + " ".to_string() + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; } fn requires_isolation_first(&self) -> bool { diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index 98feb2649763..662231164c89 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -21,7 +21,10 @@ use mysql_async::{ }; use std::{ future::Future, - sync::atomic::{AtomicBool, Ordering}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, time::Duration, }; use tokio::sync::Mutex; @@ -74,6 +77,7 @@ pub struct Mysql { socket_timeout: Option, is_healthy: AtomicBool, statement_cache: Mutex>, + transaction_depth: Arc>, } impl Mysql { @@ -87,6 +91,7 @@ impl Mysql { statement_cache: Mutex::new(url.cache()), url, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -294,4 +299,36 @@ impl Queryable for Mysql { fn requires_isolation_first(&self) -> bool { true } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; + } } diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index d656bceb1e00..c48eb33b6f38 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -24,7 +24,10 @@ use std::{ fmt::{Debug, Display}, fs, future::Future, - sync::atomic::{AtomicBool, Ordering}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, time::Duration, }; use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; @@ -50,6 +53,7 @@ pub struct PostgreSql { socket_timeout: Option, statement_cache: Mutex>, is_healthy: AtomicBool, + transaction_depth: Arc>, } #[derive(Debug)] @@ -243,6 +247,7 @@ impl PostgreSql { pg_bouncer: url.query_params.pg_bouncer, statement_cache: Mutex::new(url.cache()), is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -523,6 +528,38 @@ impl Queryable for PostgreSql { fn requires_isolation_first(&self) -> bool { false } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; + } } /// Sorted list of CockroachDB's reserved keywords. diff --git a/quaint/src/connector/queryable.rs b/quaint/src/connector/queryable.rs index 09dbc7abba4c..10e551af4ba7 100644 --- a/quaint/src/connector/queryable.rs +++ b/quaint/src/connector/queryable.rs @@ -87,8 +87,35 @@ pub trait Queryable: Send + Sync { } /// Statement to begin a transaction - fn begin_statement(&self) -> &'static str { - "BEGIN" + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; } /// Sets the transaction isolation level to given value. @@ -117,10 +144,14 @@ macro_rules! impl_default_TransactionCapable { &'a self, isolation: Option, ) -> crate::Result> { - let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first()); + let opts = crate::connector::TransactionOptions::new( + isolation, + self.requires_isolation_first(), + self.transaction_depth.clone(), + ); Ok(Box::new( - crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?, + crate::connector::DefaultTransaction::new(self, opts).await?, )) } } diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index 3bf0c46a7db5..a11d9d00f5aa 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -16,7 +16,7 @@ use crate::{ visitor::{self, Visitor}, }; use async_trait::async_trait; -use std::convert::TryFrom; +use std::{convert::TryFrom, sync::Arc}; use tokio::sync::Mutex; /// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. @@ -26,6 +26,7 @@ pub use rusqlite; /// A connector interface for the SQLite database pub struct Sqlite { pub(crate) client: Mutex, + transaction_depth: Arc>, } impl TryFrom<&str> for Sqlite { @@ -43,7 +44,10 @@ impl TryFrom<&str> for Sqlite { let client = Mutex::new(conn); - Ok(Sqlite { client }) + Ok(Sqlite { + client, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), + }) } } @@ -58,6 +62,7 @@ impl Sqlite { Ok(Sqlite { client: Mutex::new(client), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -154,6 +159,38 @@ impl Queryable for Sqlite { fn requires_isolation_first(&self) -> bool { false } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; + } } #[cfg(test)] diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index b7e91e97f6a8..0a5a9008e665 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -4,18 +4,22 @@ use crate::{ error::{Error, ErrorKind}, }; use async_trait::async_trait; +use futures::lock::Mutex; use metrics::{decrement_gauge, increment_gauge}; -use std::{fmt, str::FromStr}; +use std::{fmt, str::FromStr, sync::Arc}; extern crate metrics as metrics; #[async_trait] pub trait Transaction: Queryable { /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()>; + async fn begin(&mut self) -> crate::Result<()>; + + /// Commit the changes to the database and consume the transaction. + async fn commit(&mut self) -> crate::Result; /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()>; + async fn rollback(&mut self) -> crate::Result; /// workaround for lack of upcasting between traits https://github.com/rust-lang/rust/issues/65991 fn as_queryable(&self) -> &dyn Queryable; @@ -27,6 +31,9 @@ pub(crate) struct TransactionOptions { /// Whether or not to put the isolation level `SET` before or after the `BEGIN`. pub(crate) isolation_first: bool, + + /// The depth of the transaction, used to determine the nested transaction statements. + pub depth: Arc>, } /// A default representation of an SQL database transaction. If not commited, a @@ -36,15 +43,18 @@ pub(crate) struct TransactionOptions { /// transaction object will panic. pub struct DefaultTransaction<'a> { pub inner: &'a dyn Queryable, + pub depth: Arc>, } impl<'a> DefaultTransaction<'a> { pub(crate) async fn new( inner: &'a dyn Queryable, - begin_stmt: &str, tx_opts: TransactionOptions, ) -> crate::Result> { - let this = Self { inner }; + let mut this = Self { + inner, + depth: tx_opts.depth, + }; if tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -52,7 +62,7 @@ impl<'a> DefaultTransaction<'a> { } } - inner.raw_cmd(begin_stmt).await?; + this.begin().await?; if !tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -62,27 +72,63 @@ impl<'a> DefaultTransaction<'a> { inner.server_reset_query(&this).await?; - increment_gauge!("prisma_client_queries_active", 1.0); Ok(this) } } #[async_trait] impl<'a> Transaction for DefaultTransaction<'a> { + async fn begin(&mut self) -> crate::Result<()> { + increment_gauge!("prisma_client_queries_active", 1.0); + + let mut depth_guard = self.depth.lock().await; + + // Modify the depth value through the MutexGuard + *depth_guard += 1; + + let st_depth = *depth_guard; + + let begin_statement = self.inner.begin_statement(st_depth).await; + + self.inner.raw_cmd(&begin_statement).await?; + + Ok(()) + } + /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()> { + async fn commit(&mut self) -> crate::Result { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("COMMIT").await?; - Ok(()) + let mut depth_guard = self.depth.lock().await; + + let st_depth = *depth_guard; + + let commit_statement = self.inner.commit_statement(st_depth).await; + + self.inner.raw_cmd(&commit_statement).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + Ok(*depth_guard) } /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()> { + async fn rollback(&mut self) -> crate::Result { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("ROLLBACK").await?; - Ok(()) + let mut depth_guard = self.depth.lock().await; + + let st_depth = *depth_guard; + + let rollback_statement = self.inner.rollback_statement(st_depth).await; + + self.inner.raw_cmd(&rollback_statement).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + Ok(*depth_guard) } fn as_queryable(&self) -> &dyn Queryable { @@ -190,10 +236,11 @@ impl FromStr for IsolationLevel { } } impl TransactionOptions { - pub fn new(isolation_level: Option, isolation_first: bool) -> Self { + pub fn new(isolation_level: Option, isolation_first: bool, depth: Arc>) -> Self { Self { isolation_level, isolation_first, + depth, } } } diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 4c4152923377..458a3412ecec 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -500,7 +500,10 @@ impl Quaint { } }; - Ok(PooledConnection { inner }) + Ok(PooledConnection { + inner, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), + }) } /// Info about the connection and underlying database. diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index 73441b7609ba..087ea01e5ce3 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -10,12 +10,15 @@ use crate::{ error::Error, }; use async_trait::async_trait; +use futures::lock::Mutex; use mobc::{Connection as MobcPooled, Manager}; +use std::sync::Arc; /// A connection from the pool. Implements /// [Queryable](connector/trait.Queryable.html). pub struct PooledConnection { pub(crate) inner: MobcPooled, + pub transaction_depth: Arc>, } impl_default_TransactionCapable!(PooledConnection); @@ -62,8 +65,16 @@ impl Queryable for PooledConnection { self.inner.server_reset_query(tx).await } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/single.rs b/quaint/src/single.rs index b819259d81c7..cf1c013467b5 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -5,6 +5,7 @@ use crate::{ connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, }; use async_trait::async_trait; +use futures::lock::Mutex; use std::{fmt, sync::Arc}; #[cfg(feature = "sqlite-native")] @@ -15,6 +16,7 @@ use std::convert::TryFrom; pub struct Quaint { inner: Arc, connection_info: Arc, + transaction_depth: Arc>, } impl fmt::Debug for Quaint { @@ -162,7 +164,11 @@ impl Quaint { let connection_info = Arc::new(ConnectionInfo::from_url(url_str)?); Self::log_start(&connection_info); - Ok(Self { inner, connection_info }) + Ok(Self { + inner, + connection_info, + transaction_depth: Arc::new(Mutex::new(0)), + }) } #[cfg(feature = "sqlite-native")] @@ -175,6 +181,7 @@ impl Quaint { connection_info: Arc::new(ConnectionInfo::InMemorySqlite { db_name: DEFAULT_SQLITE_DATABASE.to_owned(), }), + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -229,8 +236,16 @@ impl Queryable for Quaint { self.inner.is_healthy() } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/tests/query.rs b/quaint/src/tests/query.rs index 06bebe1a9601..cf471fbf7330 100644 --- a/quaint/src/tests/query.rs +++ b/quaint/src/tests/query.rs @@ -64,7 +64,7 @@ async fn select_star_from(api: &mut dyn TestApi) -> crate::Result<()> { async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { let table = api.create_temp_table("value int").await?; - let tx = api.conn().start_transaction(None).await?; + let mut tx = api.conn().start_transaction(None).await?; let insert = Insert::single_into(&table).value("value", 10); let rows_affected = tx.execute(insert.into()).await?; @@ -75,6 +75,20 @@ async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { assert_eq!(Value::int32(10), res[0]); + // Check that nested transactions are also rolled back, even at multiple levels deep + let mut tx_inner = api.conn().start_transaction(None).await?; + let inner_insert1 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected1 = tx.execute(inner_insert1.into()).await?; + assert_eq!(1, inner_rows_affected1); + + let mut tx_inner2 = api.conn().start_transaction(None).await?; + let inner_insert2 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected2 = tx.execute(inner_insert2.into()).await?; + assert_eq!(1, inner_rows_affected2); + tx_inner2.commit().await?; + + tx_inner.commit().await?; + tx.rollback().await?; let select = Select::from_table(&table).column("value"); diff --git a/quaint/src/tests/query/error.rs b/quaint/src/tests/query/error.rs index 69c57332b6d3..67334858576e 100644 --- a/quaint/src/tests/query/error.rs +++ b/quaint/src/tests/query/error.rs @@ -456,7 +456,7 @@ async fn concurrent_transaction_conflict(api: &mut dyn TestApi) -> crate::Result let conn1 = api.create_additional_connection().await?; let conn2 = api.create_additional_connection().await?; - let tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; + let mut tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; let tx2 = conn2.start_transaction(Some(IsolationLevel::Serializable)).await?; tx1.query(Select::from_table(&table).into()).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs index 4372b23c282d..b0cfc27f320c 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs @@ -8,7 +8,7 @@ mod interactive_tx { #[connector_test] async fn basic_commit_workflow(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -35,7 +35,7 @@ mod interactive_tx { #[connector_test] async fn basic_rollback_workflow(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -63,7 +63,7 @@ mod interactive_tx { #[connector_test] async fn tx_expiration_cycle(mut runner: Runner) -> TestResult<()> { // Tx expires after one second. - let tx_id = runner.start_tx(5000, 1000, None).await?; + let tx_id = runner.start_tx(5000, 1000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -85,7 +85,6 @@ mod interactive_tx { let error = res.err().unwrap(); let known_err = error.as_known().unwrap(); - println!("KNOWN ERROR {known_err:?}"); assert_eq!(known_err.error_code, Cow::Borrowed("P2028")); assert!(known_err @@ -108,7 +107,7 @@ mod interactive_tx { #[connector_test] async fn no_auto_rollback(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); // Row is created @@ -135,7 +134,7 @@ mod interactive_tx { #[connector_test(only(Postgres))] async fn raw_queries(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -164,7 +163,7 @@ mod interactive_tx { #[connector_test] async fn batch_queries_success(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); let queries = vec![ @@ -190,7 +189,7 @@ mod interactive_tx { #[connector_test] async fn batch_queries_rollback(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); let queries = vec![ @@ -216,7 +215,7 @@ mod interactive_tx { #[connector_test] async fn batch_queries_failure(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); // One dup key, will cause failure of the batch. @@ -259,7 +258,7 @@ mod interactive_tx { #[connector_test] async fn tx_expiration_failure_cycle(mut runner: Runner) -> TestResult<()> { // Tx expires after one seconds. - let tx_id = runner.start_tx(5000, 1000, None).await?; + let tx_id = runner.start_tx(5000, 1000, None, None).await?; runner.set_active_tx(tx_id.clone()); // Row is created @@ -328,10 +327,10 @@ mod interactive_tx { #[connector_test(exclude(Sqlite))] async fn multiple_tx(mut runner: Runner) -> TestResult<()> { // First transaction. - let tx_id_a = runner.start_tx(2000, 2000, None).await?; + let tx_id_a = runner.start_tx(2000, 2000, None, None).await?; // Second transaction. - let tx_id_b = runner.start_tx(2000, 2000, None).await?; + let tx_id_b = runner.start_tx(2000, 2000, None, None).await?; // Execute on first transaction. runner.set_active_tx(tx_id_a.clone()); @@ -379,10 +378,10 @@ mod interactive_tx { ); // First transaction. - let tx_id_a = runner.start_tx(5000, 5000, Some("Serializable".into())).await?; + let tx_id_a = runner.start_tx(5000, 5000, Some("Serializable".into()), None).await?; // Second transaction. - let tx_id_b = runner.start_tx(5000, 5000, Some("Serializable".into())).await?; + let tx_id_b = runner.start_tx(5000, 5000, Some("Serializable".into()), None).await?; // Read on first transaction. runner.set_active_tx(tx_id_a.clone()); @@ -421,7 +420,7 @@ mod interactive_tx { #[connector_test] async fn double_commit(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -456,9 +455,81 @@ mod interactive_tx { Ok(()) } + #[connector_test(only(Postgres))] + async fn nested_commit_workflow(mut runner: Runner) -> TestResult<()> { + // Start the outer transaction + let outer_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(outer_tx_id.clone()); + + // Start the inner transaction + let inner_tx_id = runner.start_tx(5000, 5000, None, Some(outer_tx_id.clone())).await?; + runner.set_active_tx(inner_tx_id.clone()); + + // Perform operations in the inner transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 1 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":1}}}"### + ); + + let res = runner.commit_tx(inner_tx_id).await?; + assert!(res.is_ok()); + + // Perform operations in the outer transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 2 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":2}}}"### + ); + + let res = runner.commit_tx(outer_tx_id).await?; + assert!(res.is_ok()); + + Ok(()) + } + + #[connector_test(only(Postgres))] + async fn nested_commit_rollback_workflow(mut runner: Runner) -> TestResult<()> { + // Start the outer transaction + let outer_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(outer_tx_id.clone()); + + // Start the inner transaction + let inner_tx_id = runner.start_tx(5000, 5000, None, Some(outer_tx_id.clone())).await?; + runner.set_active_tx(inner_tx_id.clone()); + + // Perform operations in the inner transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 1 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":1}}}"### + ); + + let res = runner.commit_tx(inner_tx_id).await?; + assert!(res.is_ok()); + + // Perform operations in the outer transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 2 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":2}}}"### + ); + + // Now rollback the outer transaction + let res = runner.rollback_tx(outer_tx_id).await?; + assert!(res.is_ok()); + + // Assert that no records were written to the DB + let result_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(result_tx_id.clone()); + insta::assert_snapshot!( + run_query!(&runner, r#"query { findManyTestModel { id field }}"#), + @r###"{"data":{"findManyTestModel":[]}}"### + ); + let _ = runner.commit_tx(result_tx_id).await?; + + Ok(()) + } + #[connector_test] async fn double_rollback(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -495,7 +566,7 @@ mod interactive_tx { #[connector_test] async fn commit_after_rollback(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -532,7 +603,7 @@ mod interactive_tx { #[connector_test] async fn rollback_after_commit(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -575,7 +646,9 @@ mod itx_isolation { // All (SQL) connectors support serializable. #[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn basic_serializable(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("Serializable".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -597,7 +670,9 @@ mod itx_isolation { #[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm")))] async fn casing_doesnt_matter(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; @@ -608,13 +683,17 @@ mod itx_isolation { #[connector_test(only(Postgres))] async fn spacing_doesnt_matter(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Repeatable Read".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("Repeatable Read".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; assert!(res.is_ok()); - let tx_id = runner.start_tx(5000, 5000, Some("RepeatableRead".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("RepeatableRead".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; @@ -625,7 +704,7 @@ mod itx_isolation { #[connector_test(exclude(MongoDb))] async fn invalid_isolation(runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("test".to_owned())).await; + let tx_id = runner.start_tx(5000, 5000, Some("test".to_owned()), None).await; match tx_id { Ok(_) => panic!("Expected invalid isolation level string to throw an error, but it succeeded instead."), @@ -638,7 +717,7 @@ mod itx_isolation { // Mongo doesn't support isolation levels. #[connector_test(only(MongoDb))] async fn mongo_failure(runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await; + let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned()), None).await; match tx_id { Ok(_) => panic!("Expected mongo to throw an unsupported error, but it succeeded instead."), diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs index dff1ecdb03a5..35c550494de6 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs @@ -49,7 +49,7 @@ mod metrics { #[connector_test] async fn metrics_tx_do_not_go_negative(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -66,7 +66,7 @@ mod metrics { let active_transactions = get_gauge(&json, PRISMA_CLIENT_QUERIES_ACTIVE); assert_eq!(active_transactions, 0.0); - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs index a9b6c4395760..49ea6597ff6b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs @@ -90,7 +90,7 @@ mod mongodb { } async fn start_itx(runner: &mut Runner) -> TestResult { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); Ok(tx_id) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs index 3ab34b12010a..ebd8accfb356 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs @@ -82,7 +82,7 @@ impl Actor { response_sender.send(Response::Query(result)).await.unwrap(); } Message::BeginTransaction => { - let response = with_logs(runner.start_tx(10000, 10000, None), log_tx.clone()).await; + let response = with_logs(runner.start_tx(10000, 10000, None, None), log_tx.clone()).await; response_sender.send(Response::Tx(response)).await.unwrap(); } Message::RollbackTransaction(tx_id) => { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs index 313a29cdacf4..0b1e3244e420 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs @@ -5,7 +5,9 @@ mod sqlite { #[connector_test] async fn close_tx_on_error(runner: Runner) -> TestResult<()> { // Try to open a transaction with unsupported isolation error in SQLite. - let result = runner.start_tx(2000, 5000, Some("ReadUncommitted".to_owned())).await; + let result = runner + .start_tx(2000, 5000, Some("ReadUncommitted".to_owned()), None) + .await; assert!(result.is_err()); // Without the changes from https://github.com/prisma/prisma-engines/pull/4286 or @@ -16,7 +18,7 @@ mod sqlite { // IMMEDIATE if we had control over SQLite transaction type here, as that would not rely on // both transactions using the same connection if we were to pool multiple SQLite // connections in the future. - let tx = runner.start_tx(2000, 5000, None).await?; + let tx = runner.start_tx(2000, 5000, None, None).await?; runner.rollback_tx(tx).await?.unwrap(); Ok(()) diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs index 194b40f15f62..68c499f6634d 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs @@ -366,8 +366,9 @@ impl Runner { max_acquisition_millis: u64, valid_for_millis: u64, isolation_level: Option, + new_tx_id: Option, ) -> TestResult { - let tx_opts = TransactionOptions::new(max_acquisition_millis, valid_for_millis, isolation_level); + let tx_opts = TransactionOptions::new(max_acquisition_millis, valid_for_millis, isolation_level, new_tx_id); match &self.executor { RunnerExecutor::Builtin(executor) => { let id = executor diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs index 6e15d1262123..c6de318194c8 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs @@ -40,17 +40,21 @@ impl<'conn> MongoDbTransaction<'conn> { #[async_trait] impl<'conn> Transaction for MongoDbTransaction<'conn> { - async fn commit(&mut self) -> connector_interface::Result<()> { + async fn begin(&mut self) -> connector_interface::Result<()> { + Ok(()) + } + + async fn commit(&mut self) -> connector_interface::Result { decrement_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); utils::commit_with_retry(&mut self.connection.session) .await .map_err(|err| MongoError::from(err).into_connector_error())?; - Ok(()) + Ok(0) } - async fn rollback(&mut self) -> connector_interface::Result<()> { + async fn rollback(&mut self) -> connector_interface::Result { decrement_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); self.connection @@ -59,7 +63,7 @@ impl<'conn> Transaction for MongoDbTransaction<'conn> { .await .map_err(|err| MongoError::from(err).into_connector_error())?; - Ok(()) + Ok(0) } fn as_connection_like(&mut self) -> &mut dyn ConnectionLike { diff --git a/query-engine/connectors/query-connector/src/interface.rs b/query-engine/connectors/query-connector/src/interface.rs index 518f4356d547..a1c80199219d 100644 --- a/query-engine/connectors/query-connector/src/interface.rs +++ b/query-engine/connectors/query-connector/src/interface.rs @@ -30,8 +30,9 @@ pub trait Connection: ConnectionLike { #[async_trait] pub trait Transaction: ConnectionLike { - async fn commit(&mut self) -> crate::Result<()>; - async fn rollback(&mut self) -> crate::Result<()>; + async fn begin(&mut self) -> crate::Result<()>; + async fn commit(&mut self) -> crate::Result; + async fn rollback(&mut self) -> crate::Result; /// Explicit upcast of self reference. Rusts current vtable layout doesn't allow for an upcast if /// `trait A`, `trait B: A`, so that `Box as Box` works. This is a simple, explicit workaround. diff --git a/query-engine/connectors/sql-query-connector/src/database/transaction.rs b/query-engine/connectors/sql-query-connector/src/database/transaction.rs index 35adddb52ab4..e4c3d3b69bfb 100644 --- a/query-engine/connectors/sql-query-connector/src/database/transaction.rs +++ b/query-engine/connectors/sql-query-connector/src/database/transaction.rs @@ -37,21 +37,23 @@ impl<'tx> ConnectionLike for SqlConnectorTransaction<'tx> {} #[async_trait] impl<'tx> Transaction for SqlConnectorTransaction<'tx> { - async fn commit(&mut self) -> connector::Result<()> { + async fn begin(&mut self) -> connector::Result<()> { catch(self.connection_info.clone(), async move { - self.inner.commit().await.map_err(SqlError::from) + self.inner.begin().await.map_err(SqlError::from) }) .await } - async fn rollback(&mut self) -> connector::Result<()> { + async fn commit(&mut self) -> connector::Result { catch(self.connection_info.clone(), async move { - let res = self.inner.rollback().await.map_err(SqlError::from); + self.inner.commit().await.map_err(SqlError::from) + }) + .await + } - match res { - Err(SqlError::TransactionAlreadyClosed(_)) | Err(SqlError::RollbackWithoutBegin) => Ok(()), - _ => res, - } + async fn rollback(&mut self) -> connector::Result { + catch(self.connection_info.clone(), async move { + self.inner.rollback().await.map_err(SqlError::from) }) .await } diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index fee7bc68fe7b..2316267c2345 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -73,17 +73,21 @@ pub struct TransactionOptions { /// An optional pre-defined transaction id. Some value might be provided in case we want to generate /// a new id at the beginning of the transaction - #[serde(skip)] pub new_tx_id: Option, } impl TransactionOptions { - pub fn new(max_acquisition_millis: u64, valid_for_millis: u64, isolation_level: Option) -> Self { + pub fn new( + max_acquisition_millis: u64, + valid_for_millis: u64, + isolation_level: Option, + new_tx_id: Option, + ) -> Self { Self { max_acquisition_millis, valid_for_millis, isolation_level, - new_tx_id: None, + new_tx_id, } } diff --git a/query-engine/core/src/interactive_transactions/actor_manager.rs b/query-engine/core/src/interactive_transactions/actor_manager.rs index e6c1c7fbd1dc..37dae7e57332 100644 --- a/query-engine/core/src/interactive_transactions/actor_manager.rs +++ b/query-engine/core/src/interactive_transactions/actor_manager.rs @@ -72,19 +72,27 @@ impl TransactionActorManager { timeout: Duration, engine_protocol: EngineProtocol, ) -> crate::Result<()> { - let client = spawn_itx_actor( - query_schema.clone(), - tx_id.clone(), - conn, - isolation_level, - timeout, - CHANNEL_SIZE, - self.send_done.clone(), - engine_protocol, - ) - .await?; - - self.clients.write().await.insert(tx_id, client); + // Only create a client if there is no client for this transaction yet. + // otherwise, begin a new transaction/savepoint for the existing client. + if !self.clients.read().await.contains_key(&tx_id) { + let client = spawn_itx_actor( + query_schema.clone(), + tx_id.clone(), + conn, + isolation_level, + timeout, + CHANNEL_SIZE, + self.send_done.clone(), + engine_protocol, + ) + .await?; + + self.clients.write().await.insert(tx_id, client); + } else { + let client = self.get_client(&tx_id, "begin").await?; + client.begin().await?; + } + Ok(()) } diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 86ebd5c13b84..4a74867885b4 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -66,15 +66,39 @@ impl<'a> ITXServer<'a> { let _ = op.respond_to.send(TxOpResponse::Batch(result)); RunState::Continue } + TxOpRequestMsg::Begin => { + let resp = self.begin().await; + let _ = op.respond_to.send(TxOpResponse::Begin(resp)); + RunState::Continue + } TxOpRequestMsg::Commit => { let resp = self.commit().await; + let resp_value = match &resp { + Ok(val) => *val, + Err(_) => 0, + }; + let _ = op.respond_to.send(TxOpResponse::Committed(resp)); - RunState::Finished + + if resp_value > 0 { + RunState::Continue + } else { + RunState::Finished + } } TxOpRequestMsg::Rollback => { let resp = self.rollback(false).await; + let resp_value = match &resp { + Ok(val) => *val, + Err(_) => 0, + }; let _ = op.respond_to.send(TxOpResponse::RolledBack(resp)); - RunState::Finished + + if resp_value > 0 { + RunState::Continue + } else { + RunState::Finished + } } } } @@ -118,32 +142,46 @@ impl<'a> ITXServer<'a> { .await } - pub(crate) async fn commit(&mut self) -> crate::Result<()> { + pub(crate) async fn begin(&mut self) -> crate::Result<()> { if let CachedTx::Open(_) = self.cached_tx { let open_tx = self.cached_tx.as_open()?; - trace!("[{}] committing.", self.id.to_string()); - open_tx.commit().await?; - self.cached_tx = CachedTx::Committed; + trace!("[{}] beginning.", self.id.to_string()); + open_tx.begin().await?; } Ok(()) } - pub(crate) async fn rollback(&mut self, was_timeout: bool) -> crate::Result<()> { + pub(crate) async fn commit(&mut self) -> crate::Result { + if let CachedTx::Open(_) = self.cached_tx { + let open_tx = self.cached_tx.as_open()?; + trace!("[{}] committing.", self.id.to_string()); + let depth = open_tx.commit().await?; + if depth == 0 { + self.cached_tx = CachedTx::Committed; + } + return Ok(depth); + } + + Ok(0) + } + + pub(crate) async fn rollback(&mut self, was_timeout: bool) -> crate::Result { debug!("[{}] rolling back, was timed out = {was_timeout}", self.name()); if let CachedTx::Open(_) = self.cached_tx { let open_tx = self.cached_tx.as_open()?; - open_tx.rollback().await?; + let depth = open_tx.rollback().await?; if was_timeout { trace!("[{}] Expired Rolling back", self.id.to_string()); self.cached_tx = CachedTx::Expired; - } else { + } else if depth == 0 { self.cached_tx = CachedTx::RolledBack; trace!("[{}] Rolling back", self.id.to_string()); } + return Ok(depth); } - Ok(()) + Ok(0) } pub(crate) fn name(&self) -> String { @@ -158,7 +196,12 @@ pub struct ITXClient { } impl ITXClient { - pub(crate) async fn commit(&self) -> crate::Result<()> { + pub async fn begin(&self) -> crate::Result<()> { + self.send_and_receive(TxOpRequestMsg::Begin).await?; + Ok(()) + } + + pub(crate) async fn commit(&self) -> crate::Result { let msg = self.send_and_receive(TxOpRequestMsg::Commit).await?; if let TxOpResponse::Committed(resp) = msg { @@ -169,7 +212,7 @@ impl ITXClient { } } - pub(crate) async fn rollback(&self) -> crate::Result<()> { + pub(crate) async fn rollback(&self) -> crate::Result { let msg = self.send_and_receive(TxOpRequestMsg::Rollback).await?; if let TxOpResponse::RolledBack(resp) = msg { diff --git a/query-engine/core/src/interactive_transactions/messages.rs b/query-engine/core/src/interactive_transactions/messages.rs index 0dba2c096a8a..8f64a2fb712e 100644 --- a/query-engine/core/src/interactive_transactions/messages.rs +++ b/query-engine/core/src/interactive_transactions/messages.rs @@ -6,6 +6,7 @@ use tokio::sync::oneshot; pub enum TxOpRequestMsg { Commit, Rollback, + Begin, Single(Operation, Option), Batch(Vec, Option), } @@ -18,6 +19,7 @@ pub struct TxOpRequest { impl Display for TxOpRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.msg { + TxOpRequestMsg::Begin => write!(f, "Begin"), TxOpRequestMsg::Commit => write!(f, "Commit"), TxOpRequestMsg::Rollback => write!(f, "Rollback"), TxOpRequestMsg::Single(..) => write!(f, "Single"), @@ -28,8 +30,9 @@ impl Display for TxOpRequest { #[derive(Debug)] pub enum TxOpResponse { - Committed(crate::Result<()>), - RolledBack(crate::Result<()>), + Begin(crate::Result<()>), + Committed(crate::Result), + RolledBack(crate::Result), Single(crate::Result), Batch(crate::Result>>), } @@ -37,6 +40,7 @@ pub enum TxOpResponse { impl Display for TxOpResponse { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + Self::Begin(..) => write!(f, "Begin"), Self::Committed(..) => write!(f, "Committed"), Self::RolledBack(..) => write!(f, "RolledBack"), Self::Single(..) => write!(f, "Single"), diff --git a/query-engine/core/src/interactive_transactions/mod.rs b/query-engine/core/src/interactive_transactions/mod.rs index c3ee76703a06..c1f869e86b0c 100644 --- a/query-engine/core/src/interactive_transactions/mod.rs +++ b/query-engine/core/src/interactive_transactions/mod.rs @@ -1,7 +1,7 @@ use crate::CoreError; use connector::Transaction; use crosstarget_utils::time::ElapsedTimeCounter; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::fmt::Display; use tokio::time::Duration; @@ -39,7 +39,7 @@ pub(crate) use messages::*; /// the TransactionActorManager can reply with a helpful error message which explains that no operation can be performed on a closed transaction /// rather than an error message stating that the transaction does not exist. -#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize)] +#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize, Serialize)] pub struct TxId(String); const MINIMUM_TX_ID_LENGTH: usize = 24; diff --git a/query-engine/driver-adapters/src/proxy.rs b/query-engine/driver-adapters/src/proxy.rs index 8e1d39138cb6..11ad44772961 100644 --- a/query-engine/driver-adapters/src/proxy.rs +++ b/query-engine/driver-adapters/src/proxy.rs @@ -38,6 +38,9 @@ pub(crate) struct TransactionProxy { /// transaction options options: TransactionOptions, + /// begin transaction + pub begin: AdapterMethod<(), ()>, + /// commit transaction commit: AdapterMethod<(), ()>, @@ -106,11 +109,13 @@ impl DriverProxy { impl TransactionProxy { pub fn new(js_transaction: &JsObject) -> JsResult { let commit = get_named_property(js_transaction, "commit")?; + let begin = get_named_property(js_transaction, "begin")?; let rollback = get_named_property(js_transaction, "rollback")?; let options = get_named_property(js_transaction, "options")?; let options = from_js_value::(options); Ok(Self { + begin, commit, rollback, options, @@ -122,6 +127,11 @@ impl TransactionProxy { &self.options } + pub fn begin(&self) -> UnsafeFuture> + '_> { + self.closed.store(true, Ordering::Relaxed); + UnsafeFuture(self.begin.call_as_async(())) + } + /// Commits the transaction via the driver adapter. /// /// ## Cancellation safety diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index a4599019003e..2ff66ac57c7b 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -5,7 +5,7 @@ use crate::JsObject; use super::conversion; use crate::send_future::UnsafeFuture; use async_trait::async_trait; -use futures::Future; +use futures::{lock::Mutex, Future}; use quaint::connector::{ExternalConnectionInfo, ExternalConnector}; use quaint::{ connector::{metrics, IsolationLevel, Transaction}, @@ -13,6 +13,7 @@ use quaint::{ prelude::{Query as QuaintQuery, Queryable as QuaintQueryable, ResultSet, TransactionCapable}, visitor::{self, Visitor}, }; +use std::sync::Arc; use tracing::{info_span, Instrument}; /// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the @@ -207,6 +208,7 @@ impl JsBaseQueryable { pub struct JsQueryable { inner: JsBaseQueryable, driver_proxy: DriverProxy, + pub transaction_depth: Arc>, } impl std::fmt::Display for JsQueryable { @@ -292,14 +294,19 @@ impl TransactionCapable for JsQueryable { } } - let begin_stmt = tx.begin_statement(); + let mut depth_guard = self.transaction_depth.lock().await; + *depth_guard += 1; + + let st_depth = *depth_guard; + + let begin_stmt = tx.begin_statement(st_depth).await; let tx_opts = tx.options(); if tx_opts.use_phantom_query { - let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); + let begin_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); tx.raw_phantom_cmd(begin_stmt.as_str()).await?; } else { - tx.raw_cmd(begin_stmt).await?; + tx.raw_cmd(&begin_stmt).await?; } if !isolation_first { @@ -321,5 +328,6 @@ pub fn from_js(driver: JsObject) -> JsQueryable { JsQueryable { inner: JsBaseQueryable::new(common), driver_proxy, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), } } diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs index 264c363ea608..1b660e7960b9 100644 --- a/query-engine/driver-adapters/src/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -1,10 +1,12 @@ use async_trait::async_trait; +use futures::lock::Mutex; use metrics::decrement_gauge; use quaint::{ connector::{IsolationLevel, Transaction as QuaintTransaction}, prelude::{Query as QuaintQuery, Queryable, ResultSet}, Value, }; +use std::sync::Arc; use crate::proxy::{TransactionOptions, TransactionProxy}; use crate::{proxy::CommonProxy, queryable::JsBaseQueryable, send_future::UnsafeFuture}; @@ -16,11 +18,20 @@ use crate::{JsObject, JsResult}; pub(crate) struct JsTransaction { tx_proxy: TransactionProxy, inner: JsBaseQueryable, + pub depth: Arc>, + pub commit_stmt: String, + pub rollback_stmt: String, } impl JsTransaction { pub(crate) fn new(inner: JsBaseQueryable, tx_proxy: TransactionProxy) -> Self { - Self { inner, tx_proxy } + Self { + inner, + tx_proxy, + commit_stmt: "COMMIT".to_string(), + rollback_stmt: "ROLLBACK".to_string(), + depth: Arc::new(futures::lock::Mutex::new(0)), + } } pub fn options(&self) -> &TransactionOptions { @@ -35,11 +46,31 @@ impl JsTransaction { #[async_trait] impl QuaintTransaction for JsTransaction { - async fn commit(&self) -> quaint::Result<()> { + async fn begin(&mut self) -> quaint::Result<()> { + // increment of this gauge is done in DriverProxy::startTransaction + decrement_gauge!("prisma_client_queries_active", 1.0); + + let mut depth_guard = self.depth.lock().await; + let commit_stmt = "BEGIN"; + + if self.options().use_phantom_query { + let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); + self.raw_phantom_cmd(commit_stmt.as_str()).await?; + } else { + self.inner.raw_cmd(commit_stmt).await?; + } + + // Modify the depth value through the MutexGuard + *depth_guard += 1; + + UnsafeFuture(self.tx_proxy.begin()).await + } + async fn commit(&mut self) -> quaint::Result { // increment of this gauge is done in DriverProxy::startTransaction decrement_gauge!("prisma_client_queries_active", 1.0); - let commit_stmt = "COMMIT"; + let mut depth_guard = self.depth.lock().await; + let commit_stmt = &self.commit_stmt; if self.options().use_phantom_query { let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); @@ -48,14 +79,20 @@ impl QuaintTransaction for JsTransaction { self.inner.raw_cmd(commit_stmt).await?; } - UnsafeFuture(self.tx_proxy.commit()).await + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + let _ = UnsafeFuture(self.tx_proxy.commit()).await; + + Ok(*depth_guard) } - async fn rollback(&self) -> quaint::Result<()> { + async fn rollback(&mut self) -> quaint::Result { // increment of this gauge is done in DriverProxy::startTransaction decrement_gauge!("prisma_client_queries_active", 1.0); - let rollback_stmt = "ROLLBACK"; + let mut depth_guard = self.depth.lock().await; + let rollback_stmt = &self.rollback_stmt; if self.options().use_phantom_query { let rollback_stmt = JsBaseQueryable::phantom_query_message(rollback_stmt); @@ -64,7 +101,12 @@ impl QuaintTransaction for JsTransaction { self.inner.raw_cmd(rollback_stmt).await?; } - UnsafeFuture(self.tx_proxy.rollback()).await + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + let _ = UnsafeFuture(self.tx_proxy.rollback()).await; + + Ok(*depth_guard) } fn as_queryable(&self) -> &dyn Queryable { diff --git a/query-engine/query-engine/src/server/mod.rs b/query-engine/query-engine/src/server/mod.rs index f3583df310d7..ba1f4d4f13bd 100644 --- a/query-engine/query-engine/src/server/mod.rs +++ b/query-engine/query-engine/src/server/mod.rs @@ -282,7 +282,11 @@ async fn transaction_start_handler(cx: Arc, req: Request) - let body_start = req.into_body(); let full_body = hyper::body::to_bytes(body_start).await?; let mut tx_opts: TransactionOptions = serde_json::from_slice(full_body.as_ref()).unwrap(); - let tx_id = tx_opts.with_new_transaction_id(); + let tx_id = if tx_opts.new_tx_id.is_none() { + tx_opts.with_new_transaction_id() + } else { + tx_opts.new_tx_id.clone().unwrap() + }; // This is the span we use to instrument the execution of a transaction. This span will be open // during the tx execution, and held in the ITXServer for that transaction (see ITXServer])