diff --git a/Makefile b/Makefile index 7407f7d41fe3..be6a9c8be76b 100644 --- a/Makefile +++ b/Makefile @@ -411,8 +411,8 @@ ensure-prisma-present: echo "⚠️ ../prisma diverges from prisma/prisma main branch. Test results might diverge from those in CI ⚠️ "; \ fi \ else \ - echo "git clone --depth=1 https://github.com/prisma/prisma.git --branch=$(DRIVER_ADAPTERS_BRANCH) ../prisma"; \ - git clone --depth=1 https://github.com/prisma/prisma.git --branch=$(DRIVER_ADAPTERS_BRANCH) "../prisma" && echo "Prisma repository has been cloned to ../prisma"; \ + echo "git clone --depth=1 https://github.com/LucianBuzzo/prisma.git --branch=lucianbuzzo/nested-rollbacks ../prisma"; \ + git clone --depth=1 https://github.com/LucianBuzzo/prisma.git --branch=lucianbuzzo/nested-rollbacks "../prisma" && echo "Prisma repository has been cloned to ../prisma"; \ fi; # Quick schema validation of whatever you have in the dev_datamodel.prisma file. diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index fe7751ddf373..869d242d066e 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -15,6 +15,7 @@ use crate::{ }; use async_trait::async_trait; use futures::lock::Mutex; +use std::borrow::Cow; use std::{ convert::TryFrom, future::Future, @@ -48,9 +49,7 @@ impl TransactionCapable for Mssql { let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); - Ok(Box::new( - DefaultTransaction::new(self, self.begin_statement(), opts).await?, - )) + Ok(Box::new(DefaultTransaction::new(self, opts).await?)) } } @@ -244,8 +243,33 @@ impl Queryable for Mssql { Ok(()) } - fn begin_statement(&self) -> &'static str { - "BEGIN TRAN" + /// Statement to begin a transaction + fn begin_statement(&self, depth: u32) -> Cow<'static, str> { + if depth > 1 { + Cow::Owned(format!("SAVE TRANSACTION savepoint{depth}")) + } else { + Cow::Borrowed("BEGIN TRAN") + } + } + + /// Statement to commit a transaction + fn commit_statement(&self, depth: u32) -> Cow<'static, str> { + // MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested + // transaction we just continue onwards + if depth > 1 { + Cow::Owned("".to_string()) + } else { + Cow::Borrowed("COMMIT") + } + } + + /// Statement to rollback a transaction + fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { + if depth > 1 { + Cow::Owned(format!("ROLLBACK TRANSACTION savepoint{depth}")) + } else { + Cow::Borrowed("ROLLBACK") + } } fn requires_isolation_first(&self) -> bool { diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index 2c8a757a48e7..94dd25f5c0e2 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -21,6 +21,7 @@ use mysql_async::{ self as my, prelude::{Query as _, Queryable as _}, }; +use std::borrow::Cow; use std::{ future::Future, sync::atomic::{AtomicBool, Ordering}, @@ -347,4 +348,31 @@ impl Queryable for Mysql { fn requires_isolation_first(&self) -> bool { true } + + /// Statement to begin a transaction + fn begin_statement(&self, depth: u32) -> Cow<'static, str> { + if depth > 1 { + Cow::Owned(format!("SAVEPOINT savepoint{depth}")) + } else { + Cow::Borrowed("BEGIN") + } + } + + /// Statement to commit a transaction + fn commit_statement(&self, depth: u32) -> Cow<'static, str> { + if depth > 1 { + Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}")) + } else { + Cow::Borrowed("COMMIT") + } + } + + /// Statement to rollback a transaction + fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { + if depth > 1 { + Cow::Owned(format!("ROLLBACK TO savepoint{depth}")) + } else { + Cow::Borrowed("ROLLBACK") + } + } } diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index eb6618ce9dc7..8cd1e1a07577 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -29,6 +29,7 @@ use native_tls::{Certificate, Identity, TlsConnector}; use postgres_native_tls::MakeTlsConnector; use postgres_types::{Kind as PostgresKind, Type as PostgresType}; use prisma_metrics::WithMetricsInstrumentation; +use std::borrow::Cow; use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ fmt::{Debug, Display}, @@ -806,6 +807,33 @@ impl Queryable for PostgreSql { fn requires_isolation_first(&self) -> bool { false } + + /// Statement to begin a transaction + fn begin_statement(&self, depth: u32) -> Cow<'static, str> { + if depth > 1 { + Cow::Owned(format!("SAVEPOINT savepoint{depth}")) + } else { + Cow::Borrowed("BEGIN") + } + } + + /// Statement to commit a transaction + fn commit_statement(&self, depth: u32) -> Cow<'static, str> { + if depth > 1 { + Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}")) + } else { + Cow::Borrowed("COMMIT") + } + } + + /// Statement to rollback a transaction + fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { + if depth > 1 { + Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}")) + } else { + Cow::Borrowed("ROLLBACK") + } + } } /// Sorted list of CockroachDB's reserved keywords. diff --git a/quaint/src/connector/queryable.rs b/quaint/src/connector/queryable.rs index 5f0fd54dad6b..b34e42866d85 100644 --- a/quaint/src/connector/queryable.rs +++ b/quaint/src/connector/queryable.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use super::{DescribedQuery, IsolationLevel, ResultSet, Transaction}; use crate::ast::*; use async_trait::async_trait; @@ -90,8 +92,30 @@ pub trait Queryable: Send + Sync { } /// Statement to begin a transaction - fn begin_statement(&self) -> &'static str { - "BEGIN" + fn begin_statement(&self, depth: u32) -> Cow<'static, str> { + if depth > 1 { + Cow::Owned(format!("SAVEPOINT savepoint{depth}")) + } else { + Cow::Borrowed("BEGIN") + } + } + + /// Statement to commit a transaction + fn commit_statement(&self, depth: u32) -> Cow<'static, str> { + if depth > 1 { + Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}")) + } else { + Cow::Borrowed("COMMIT") + } + } + + /// Statement to rollback a transaction + fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { + if depth > 1 { + Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}")) + } else { + Cow::Borrowed("ROLLBACK") + } } /// Sets the transaction isolation level to given value. @@ -123,7 +147,7 @@ macro_rules! impl_default_TransactionCapable { let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first()); Ok(Box::new( - crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?, + crate::connector::DefaultTransaction::new(self, opts).await?, )) } } diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index 2d738a7f087f..d4cf610af6f5 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -17,6 +17,7 @@ use crate::{ visitor::{self, Visitor}, }; use async_trait::async_trait; +use std::borrow::Cow; use std::convert::TryFrom; use tokio::sync::Mutex; @@ -183,12 +184,35 @@ impl Queryable for Sqlite { false } - fn begin_statement(&self) -> &'static str { + /// Statement to begin a transaction + fn begin_statement(&self, depth: u32) -> Cow<'static, str> { // From https://sqlite.org/isolation.html: // `BEGIN IMMEDIATE` avoids possible `SQLITE_BUSY_SNAPSHOT` that arise when another connection jumps ahead in line. // The BEGIN IMMEDIATE command goes ahead and starts a write transaction, and thus blocks all other writers. // If the BEGIN IMMEDIATE operation succeeds, then no subsequent operations in that transaction will ever fail with an SQLITE_BUSY error. - "BEGIN IMMEDIATE" + if depth > 1 { + Cow::Owned(format!("SAVEPOINT savepoint{depth}")) + } else { + Cow::Borrowed("BEGIN IMMEDIATE") + } + } + + /// Statement to commit a transaction + fn commit_statement(&self, depth: u32) -> Cow<'static, str> { + if depth > 1 { + Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}")) + } else { + Cow::Borrowed("COMMIT") + } + } + + /// Statement to rollback a transaction + fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { + if depth > 1 { + Cow::Owned(format!("ROLLBACK TO savepoint{depth}")) + } else { + Cow::Borrowed("ROLLBACK") + } } } diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index 599efe1d99fc..461ce6610ee2 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -1,5 +1,3 @@ -use std::{fmt, str::FromStr}; - use async_trait::async_trait; use prisma_metrics::guards::GaugeGuard; @@ -8,14 +6,22 @@ use crate::{ ast::*, error::{Error, ErrorKind}, }; +use std::{ + fmt, + str::FromStr, + sync::{Arc, Mutex}, +}; #[async_trait] pub trait Transaction: Queryable { + /// Start a new transaction or nested transaction via savepoint. + async fn begin(&mut self) -> crate::Result<()>; + /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()>; + async fn commit(&mut self) -> crate::Result; /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()>; + async fn rollback(&mut self) -> crate::Result; /// workaround for lack of upcasting between traits https://github.com/rust-lang/rust/issues/65991 fn as_queryable(&self) -> &dyn Queryable; @@ -36,18 +42,19 @@ pub(crate) struct TransactionOptions { /// transaction object will panic. pub struct DefaultTransaction<'a> { pub inner: &'a dyn Queryable, + pub depth: Arc>, gauge: GaugeGuard, } impl<'a> DefaultTransaction<'a> { pub(crate) async fn new( inner: &'a dyn Queryable, - begin_stmt: &str, tx_opts: TransactionOptions, ) -> crate::Result> { - let this = Self { + let mut this = Self { inner, gauge: GaugeGuard::increment("prisma_client_queries_active"), + depth: Arc::new(Mutex::new(0)), }; if tx_opts.isolation_first { @@ -56,7 +63,7 @@ impl<'a> DefaultTransaction<'a> { } } - inner.raw_cmd(begin_stmt).await?; + this.begin().await?; if !tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -72,20 +79,71 @@ impl<'a> DefaultTransaction<'a> { #[async_trait] impl<'a> Transaction for DefaultTransaction<'a> { - /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()> { - self.gauge.decrement(); - self.inner.raw_cmd("COMMIT").await?; + async fn begin(&mut self) -> crate::Result<()> { + let current_depth = { + let mut depth = self.depth.lock().unwrap(); + *depth += 1; + *depth + }; + + let begin_statement = self.inner.begin_statement(current_depth); + + self.inner.raw_cmd(&begin_statement).await?; Ok(()) } + /// Commit the changes to the database and consume the transaction. + async fn commit(&mut self) -> crate::Result { + // Lock the mutex and get the depth value + let depth_val = { + let depth = self.depth.lock().unwrap(); + *depth + }; + + // Perform the asynchronous operation without holding the lock + let commit_statement = self.inner.commit_statement(depth_val); + self.inner.raw_cmd(&commit_statement).await?; + + // Lock the mutex again to modify the depth + let new_depth = { + let mut depth = self.depth.lock().unwrap(); + *depth -= 1; + *depth + }; + + if new_depth == 0 { + self.gauge.decrement(); + } + + Ok(new_depth) + } + /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()> { - self.gauge.decrement(); - self.inner.raw_cmd("ROLLBACK").await?; + async fn rollback(&mut self) -> crate::Result { + // Lock the mutex and get the depth value + let depth_val = { + let depth = self.depth.lock().unwrap(); + *depth + }; - Ok(()) + // Perform the asynchronous operation without holding the lock + let rollback_statement = self.inner.rollback_statement(depth_val); + + self.inner.raw_cmd(&rollback_statement).await?; + + // Lock the mutex again to modify the depth + let new_depth = { + let mut depth = self.depth.lock().unwrap(); + *depth -= 1; + *depth + }; + + if new_depth == 0 { + self.gauge.decrement(); + } + + Ok(new_depth) } fn as_queryable(&self) -> &dyn Queryable { diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index bf4d50eeea87..22bea71778a3 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -3,6 +3,7 @@ use std::future::Future; use async_trait::async_trait; use mobc::{Connection as MobcPooled, Manager}; use prisma_metrics::WithMetricsInstrumentation; +use std::borrow::Cow; use tracing_futures::WithSubscriber; #[cfg(feature = "mssql-native")] @@ -71,8 +72,16 @@ impl Queryable for PooledConnection { self.inner.server_reset_query(tx).await } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + fn begin_statement(&self, depth: u32) -> Cow<'static, str> { + self.inner.begin_statement(depth) + } + + fn commit_statement(&self, depth: u32) -> Cow<'static, str> { + self.inner.commit_statement(depth) + } + + fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { + self.inner.rollback_statement(depth) } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 13be8c4bc857..9953d021ae02 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -5,7 +5,7 @@ use crate::{ connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, }; use async_trait::async_trait; -use std::{fmt, sync::Arc}; +use std::{borrow::Cow, fmt, sync::Arc}; #[cfg(feature = "sqlite-native")] use std::convert::TryFrom; @@ -238,8 +238,16 @@ impl Queryable for Quaint { self.inner.is_healthy() } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + fn begin_statement(&self, depth: u32) -> Cow<'static, str> { + self.inner.begin_statement(depth) + } + + fn commit_statement(&self, depth: u32) -> Cow<'static, str> { + self.inner.commit_statement(depth) + } + + fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { + self.inner.rollback_statement(depth) } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/tests/query.rs b/quaint/src/tests/query.rs index 6e83297a9a75..5c7c96360529 100644 --- a/quaint/src/tests/query.rs +++ b/quaint/src/tests/query.rs @@ -64,7 +64,7 @@ async fn select_star_from(api: &mut dyn TestApi) -> crate::Result<()> { async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { let table = api.create_temp_table("value int").await?; - let tx = api.conn().start_transaction(None).await?; + let mut tx = api.conn().start_transaction(None).await?; let insert = Insert::single_into(&table).value("value", 10); let rows_affected = tx.execute(insert.into()).await?; @@ -75,6 +75,21 @@ async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { assert_eq!(Value::int32(10), res[0]); + // Check that nested transactions are also rolled back, even at multiple levels deep + tx.begin().await?; + let inner_insert1 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected1 = tx.execute(inner_insert1.into()).await?; + assert_eq!(1, inner_rows_affected1); + + // Open another nested transaction + tx.begin().await?; + let inner_insert2 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected2 = tx.execute(inner_insert2.into()).await?; + assert_eq!(1, inner_rows_affected2); + tx.commit().await?; + + tx.commit().await?; + tx.rollback().await?; let select = Select::from_table(&table).column("value"); diff --git a/quaint/src/tests/query/error.rs b/quaint/src/tests/query/error.rs index 399866bd4a3b..424d2a0348ea 100644 --- a/quaint/src/tests/query/error.rs +++ b/quaint/src/tests/query/error.rs @@ -456,7 +456,7 @@ async fn concurrent_transaction_conflict(api: &mut dyn TestApi) -> crate::Result let conn1 = api.create_additional_connection().await?; let conn2 = api.create_additional_connection().await?; - let tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; + let mut tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; let tx2 = conn2.start_transaction(Some(IsolationLevel::Serializable)).await?; tx1.query(Select::from_table(&table).into()).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs index 100828697046..972dcbf59146 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs @@ -10,7 +10,7 @@ mod interactive_tx { #[connector_test] async fn basic_commit_workflow(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -37,7 +37,7 @@ mod interactive_tx { #[connector_test] async fn basic_rollback_workflow(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -65,7 +65,7 @@ mod interactive_tx { #[connector_test] async fn tx_expiration_cycle(mut runner: Runner) -> TestResult<()> { // Tx expires after one second. - let tx_id = runner.start_tx(5000, 1000, None).await?; + let tx_id = runner.start_tx(5000, 1000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -110,7 +110,7 @@ mod interactive_tx { #[connector_test] async fn no_auto_rollback(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); // Row is created @@ -137,7 +137,7 @@ mod interactive_tx { #[connector_test(only(Postgres))] async fn raw_queries(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -166,7 +166,7 @@ mod interactive_tx { #[connector_test] async fn batch_queries_success(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); let queries = vec![ @@ -192,7 +192,7 @@ mod interactive_tx { #[connector_test] async fn batch_queries_rollback(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); let queries = vec![ @@ -218,7 +218,7 @@ mod interactive_tx { #[connector_test(exclude(Sqlite("cfd1")))] async fn batch_queries_failure(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); // One dup key, will cause failure of the batch. @@ -263,7 +263,7 @@ mod interactive_tx { #[connector_test] async fn tx_expiration_failure_cycle(mut runner: Runner) -> TestResult<()> { // Tx expires after one seconds. - let tx_id = runner.start_tx(5000, 1000, None).await?; + let tx_id = runner.start_tx(5000, 1000, None, None).await?; runner.set_active_tx(tx_id.clone()); // Row is created @@ -332,10 +332,10 @@ mod interactive_tx { #[connector_test(exclude(Sqlite))] async fn multiple_tx(mut runner: Runner) -> TestResult<()> { // First transaction. - let tx_id_a = runner.start_tx(2000, 2000, None).await?; + let tx_id_a = runner.start_tx(2000, 2000, None, None).await?; // Second transaction. - let tx_id_b = runner.start_tx(2000, 2000, None).await?; + let tx_id_b = runner.start_tx(2000, 2000, None, None).await?; // Execute on first transaction. runner.set_active_tx(tx_id_a.clone()); @@ -383,10 +383,10 @@ mod interactive_tx { ); // First transaction. - let tx_id_a = runner.start_tx(5000, 5000, Some("Serializable".into())).await?; + let tx_id_a = runner.start_tx(5000, 5000, Some("Serializable".into()), None).await?; // Second transaction. - let tx_id_b = runner.start_tx(5000, 5000, Some("Serializable".into())).await?; + let tx_id_b = runner.start_tx(5000, 5000, Some("Serializable".into()), None).await?; // Read on first transaction. runner.set_active_tx(tx_id_a.clone()); @@ -425,7 +425,7 @@ mod interactive_tx { #[connector_test] async fn double_commit(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -460,9 +460,81 @@ mod interactive_tx { Ok(()) } + #[connector_test(only(Postgres))] + async fn nested_commit_workflow(mut runner: Runner) -> TestResult<()> { + // Start the outer transaction + let outer_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(outer_tx_id.clone()); + + // Start the inner transaction + let inner_tx_id = runner.start_tx(5000, 5000, None, Some(outer_tx_id.clone())).await?; + runner.set_active_tx(inner_tx_id.clone()); + + // Perform operations in the inner transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 1 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":1}}}"### + ); + + let res = runner.commit_tx(inner_tx_id).await?; + assert!(res.is_ok()); + + // Perform operations in the outer transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 2 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":2}}}"### + ); + + let res = runner.commit_tx(outer_tx_id).await?; + assert!(res.is_ok()); + + Ok(()) + } + + #[connector_test(only(Postgres))] + async fn nested_commit_rollback_workflow(mut runner: Runner) -> TestResult<()> { + // Start the outer transaction + let outer_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(outer_tx_id.clone()); + + // Start the inner transaction + let inner_tx_id = runner.start_tx(5000, 5000, None, Some(outer_tx_id.clone())).await?; + runner.set_active_tx(inner_tx_id.clone()); + + // Perform operations in the inner transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 1 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":1}}}"### + ); + + let res = runner.commit_tx(inner_tx_id).await?; + assert!(res.is_ok()); + + // Perform operations in the outer transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 2 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":2}}}"### + ); + + // Now rollback the outer transaction + let res = runner.rollback_tx(outer_tx_id).await?; + assert!(res.is_ok()); + + // Assert that no records were written to the DB + let result_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(result_tx_id.clone()); + insta::assert_snapshot!( + run_query!(&runner, r#"query { findManyTestModel { id field }}"#), + @r###"{"data":{"findManyTestModel":[]}}"### + ); + let _ = runner.commit_tx(result_tx_id).await?; + + Ok(()) + } + #[connector_test] async fn double_rollback(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -499,7 +571,7 @@ mod interactive_tx { #[connector_test] async fn commit_after_rollback(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -536,7 +608,7 @@ mod interactive_tx { #[connector_test] async fn rollback_after_commit(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -582,7 +654,9 @@ mod itx_isolation { // All (SQL) connectors support serializable. #[connector_test(exclude(MongoDb, Sqlite("cfd1")))] async fn basic_serializable(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("Serializable".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -604,7 +678,9 @@ mod itx_isolation { #[connector_test(exclude(MongoDb, Sqlite("cfd1")))] async fn casing_doesnt_matter(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; @@ -615,13 +691,17 @@ mod itx_isolation { #[connector_test(only(Postgres))] async fn spacing_doesnt_matter(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Repeatable Read".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("Repeatable Read".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; assert!(res.is_ok()); - let tx_id = runner.start_tx(5000, 5000, Some("RepeatableRead".to_owned())).await?; + let tx_id = runner + .start_tx(5000, 5000, Some("RepeatableRead".to_owned()), None) + .await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; @@ -632,7 +712,7 @@ mod itx_isolation { #[connector_test(exclude(MongoDb))] async fn invalid_isolation(runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("test".to_owned())).await; + let tx_id = runner.start_tx(5000, 5000, Some("test".to_owned()), None).await; match tx_id { Ok(_) => panic!("Expected invalid isolation level string to throw an error, but it succeeded instead."), @@ -645,7 +725,7 @@ mod itx_isolation { // Mongo doesn't support isolation levels. #[connector_test(only(MongoDb))] async fn mongo_failure(runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await; + let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned()), None).await; match tx_id { Ok(_) => panic!("Expected mongo to throw an unsupported error, but it succeeded instead."), @@ -666,7 +746,7 @@ mod itx_isolation { set.spawn({ let runner = Arc::clone(&runner); async move { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner .query_in_tx( diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs index 323f162a2111..d634102ee281 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs @@ -46,7 +46,7 @@ mod metrics { #[connector_test] async fn metrics_tx_do_not_go_negative(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -63,7 +63,7 @@ mod metrics { let active_transactions = utils::metrics::get_gauge(&json, PRISMA_CLIENT_QUERIES_ACTIVE); assert_eq!(active_transactions, 0.0); - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_11750.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_11750.rs index 907aae408bf9..80951461410c 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_11750.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_11750.rs @@ -96,7 +96,7 @@ mod prisma_11750 { } async fn update_user(runner: Arc, new_email: &str) -> TestResult<()> { - let tx_id = runner.start_tx(2000, 25, None).await?; + let tx_id = runner.start_tx(2000, 25, None, None).await?; let result = runner .query_in_tx( diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs index a9b6c4395760..49ea6597ff6b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs @@ -90,7 +90,7 @@ mod mongodb { } async fn start_itx(runner: &mut Runner) -> TestResult { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); Ok(tx_id) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs index 7ed3cb9a8598..cf8bcdaf1ad8 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs @@ -83,9 +83,12 @@ impl Actor { response_sender.send(Response::Query(result)).await.unwrap(); } Message::BeginTransaction => { - let response = - with_observability(runner.start_tx(10000, 10000, None), log_tx.clone(), recorder.clone()) - .await; + let response = with_observability( + runner.start_tx(10000, 10000, None, None), + log_tx.clone(), + recorder.clone(), + ) + .await; response_sender.send(Response::Tx(response)).await.unwrap(); } Message::RollbackTransaction(tx_id) => { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs index 313a29cdacf4..0b1e3244e420 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_engines_4286.rs @@ -5,7 +5,9 @@ mod sqlite { #[connector_test] async fn close_tx_on_error(runner: Runner) -> TestResult<()> { // Try to open a transaction with unsupported isolation error in SQLite. - let result = runner.start_tx(2000, 5000, Some("ReadUncommitted".to_owned())).await; + let result = runner + .start_tx(2000, 5000, Some("ReadUncommitted".to_owned()), None) + .await; assert!(result.is_err()); // Without the changes from https://github.com/prisma/prisma-engines/pull/4286 or @@ -16,7 +18,7 @@ mod sqlite { // IMMEDIATE if we had control over SQLite transaction type here, as that would not rely on // both transactions using the same connection if we were to pool multiple SQLite // connections in the future. - let tx = runner.start_tx(2000, 5000, None).await?; + let tx = runner.start_tx(2000, 5000, None, None).await?; runner.rollback_tx(tx).await?.unwrap(); Ok(()) diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs index de8ee9bd33be..ed3f6de13740 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs @@ -517,8 +517,9 @@ impl Runner { max_acquisition_millis: u64, valid_for_millis: u64, isolation_level: Option, + new_tx_id: Option, ) -> TestResult { - let tx_opts = TransactionOptions::new(max_acquisition_millis, valid_for_millis, isolation_level); + let tx_opts = TransactionOptions::new(max_acquisition_millis, valid_for_millis, isolation_level, new_tx_id); match &self.executor { RunnerExecutor::Builtin(executor) => { let id = executor diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs index 6045d06b442d..163c78c8a393 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs @@ -44,17 +44,21 @@ impl<'conn> MongoDbTransaction<'conn> { #[async_trait] impl<'conn> Transaction for MongoDbTransaction<'conn> { - async fn commit(&mut self) -> connector_interface::Result<()> { + async fn begin(&mut self) -> connector_interface::Result<()> { + Ok(()) + } + + async fn commit(&mut self) -> connector_interface::Result { self.gauge.decrement(); utils::commit_with_retry(&mut self.connection.session) .await .map_err(|err| MongoError::from(err).into_connector_error())?; - Ok(()) + Ok(0) } - async fn rollback(&mut self) -> connector_interface::Result<()> { + async fn rollback(&mut self) -> connector_interface::Result { self.gauge.decrement(); self.connection @@ -63,7 +67,7 @@ impl<'conn> Transaction for MongoDbTransaction<'conn> { .await .map_err(|err| MongoError::from(err).into_connector_error())?; - Ok(()) + Ok(0) } async fn version(&self) -> Option { diff --git a/query-engine/connectors/query-connector/src/interface.rs b/query-engine/connectors/query-connector/src/interface.rs index 05e8f1e1098f..fccb9aaaccfa 100644 --- a/query-engine/connectors/query-connector/src/interface.rs +++ b/query-engine/connectors/query-connector/src/interface.rs @@ -36,8 +36,9 @@ pub trait Connection: ConnectionLike { #[async_trait] pub trait Transaction: ConnectionLike { - async fn commit(&mut self) -> crate::Result<()>; - async fn rollback(&mut self) -> crate::Result<()>; + async fn begin(&mut self) -> crate::Result<()>; + async fn commit(&mut self) -> crate::Result; + async fn rollback(&mut self) -> crate::Result; async fn version(&self) -> Option; diff --git a/query-engine/connectors/sql-query-connector/src/database/transaction.rs b/query-engine/connectors/sql-query-connector/src/database/transaction.rs index 387b18f63ee2..ce7e102b92ad 100644 --- a/query-engine/connectors/sql-query-connector/src/database/transaction.rs +++ b/query-engine/connectors/sql-query-connector/src/database/transaction.rs @@ -38,19 +38,26 @@ impl<'tx> ConnectionLike for SqlConnectorTransaction<'tx> {} #[async_trait] impl<'tx> Transaction for SqlConnectorTransaction<'tx> { - async fn commit(&mut self) -> connector::Result<()> { + async fn begin(&mut self) -> connector::Result<()> { + catch(&self.connection_info, async { + self.inner.begin().await.map_err(SqlError::from) + }) + .await + } + + async fn commit(&mut self) -> connector::Result { catch(&self.connection_info, async { self.inner.commit().await.map_err(SqlError::from) }) .await } - async fn rollback(&mut self) -> connector::Result<()> { + async fn rollback(&mut self) -> connector::Result { catch(&self.connection_info, async { let res = self.inner.rollback().await.map_err(SqlError::from); match res { - Err(SqlError::TransactionAlreadyClosed(_)) | Err(SqlError::RollbackWithoutBegin) => Ok(()), + Err(SqlError::TransactionAlreadyClosed(_)) | Err(SqlError::RollbackWithoutBegin) => Ok(0), _ => res, } }) diff --git a/query-engine/core/src/executor/interpreting_executor.rs b/query-engine/core/src/executor/interpreting_executor.rs index 2e391461c718..0f73c99d2ddc 100644 --- a/query-engine/core/src/executor/interpreting_executor.rs +++ b/query-engine/core/src/executor/interpreting_executor.rs @@ -189,7 +189,7 @@ where self.itx_manager.commit_tx(&tx_id).await } - async fn rollback_tx(&self, tx_id: TxId) -> crate::Result<()> { + async fn rollback_tx(&self, tx_id: TxId) -> crate::Result { self.itx_manager.rollback_tx(&tx_id).await } } diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index c7846f7ff7cb..0146a94eb93e 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -73,17 +73,21 @@ pub struct TransactionOptions { /// An optional pre-defined transaction id. Some value might be provided in case we want to generate /// a new id at the beginning of the transaction - #[serde(skip)] pub new_tx_id: Option, } impl TransactionOptions { - pub fn new(max_acquisition_millis: u64, valid_for_millis: u64, isolation_level: Option) -> Self { + pub fn new( + max_acquisition_millis: u64, + valid_for_millis: u64, + isolation_level: Option, + new_tx_id: Option, + ) -> Self { Self { max_acquisition_millis, valid_for_millis, isolation_level, - new_tx_id: None, + new_tx_id, } } @@ -114,5 +118,5 @@ pub trait TransactionManager { async fn commit_tx(&self, tx_id: TxId) -> crate::Result<()>; /// Rolls back a transaction. - async fn rollback_tx(&self, tx_id: TxId) -> crate::Result<()>; + async fn rollback_tx(&self, tx_id: TxId) -> crate::Result; } diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index edf70b3d0109..cb947683dca0 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -71,15 +71,39 @@ impl<'a> ITXServer<'a> { let _ = op.respond_to.send(TxOpResponse::Batch(result)); RunState::Continue } + TxOpRequestMsg::Begin => { + let _result = self.begin().await; + let _ = op.respond_to.send(TxOpResponse::Begin(())); + RunState::Continue + } TxOpRequestMsg::Commit => { let resp = self.commit().await; + let resp_value = match &resp { + Ok(val) => *val, + Err(_) => 0, + }; + let _ = op.respond_to.send(TxOpResponse::Committed(resp)); - RunState::Finished + + if resp_value > 0 { + RunState::Continue + } else { + RunState::Finished + } } TxOpRequestMsg::Rollback => { let resp = self.rollback(false).await; + let resp_value = match &resp { + Ok(val) => *val, + Err(_) => 0, + }; let _ = op.respond_to.send(TxOpResponse::RolledBack(resp)); - RunState::Finished + + if resp_value > 0 { + RunState::Continue + } else { + RunState::Finished + } } } } @@ -123,32 +147,46 @@ impl<'a> ITXServer<'a> { .await } - pub(crate) async fn commit(&mut self) -> crate::Result<()> { + pub(crate) async fn begin(&mut self) -> crate::Result<()> { if let CachedTx::Open(_) = self.cached_tx { let open_tx = self.cached_tx.as_open()?; - trace!("[{}] committing.", self.id.to_string()); - open_tx.commit().await?; - self.cached_tx = CachedTx::Committed; + trace!("[{}] beginning.", self.id.to_string()); + open_tx.begin().await?; } Ok(()) } - pub(crate) async fn rollback(&mut self, was_timeout: bool) -> crate::Result<()> { + pub(crate) async fn commit(&mut self) -> crate::Result { + if let CachedTx::Open(_) = self.cached_tx { + let open_tx = self.cached_tx.as_open()?; + trace!("[{}] committing.", self.id.to_string()); + let depth = open_tx.commit().await?; + if depth == 0 { + self.cached_tx = CachedTx::Committed; + } + return Ok(depth); + } + + Ok(0) + } + + pub(crate) async fn rollback(&mut self, was_timeout: bool) -> crate::Result { debug!("[{}] rolling back, was timed out = {was_timeout}", self.name()); if let CachedTx::Open(_) = self.cached_tx { let open_tx = self.cached_tx.as_open()?; - open_tx.rollback().await?; + let depth = open_tx.rollback().await?; if was_timeout { trace!("[{}] Expired Rolling back", self.id.to_string()); self.cached_tx = CachedTx::Expired; - } else { + } else if depth == 0 { self.cached_tx = CachedTx::RolledBack; trace!("[{}] Rolling back", self.id.to_string()); } + return Ok(depth); } - Ok(()) + Ok(0) } pub(crate) fn name(&self) -> String { @@ -163,7 +201,12 @@ pub struct ITXClient { } impl ITXClient { - pub(crate) async fn commit(&self) -> crate::Result<()> { + pub async fn begin(&self) -> crate::Result<()> { + self.send_and_receive(TxOpRequestMsg::Begin).await?; + Ok(()) + } + + pub(crate) async fn commit(&self) -> crate::Result { let msg = self.send_and_receive(TxOpRequestMsg::Commit).await?; if let TxOpResponse::Committed(resp) = msg { @@ -174,7 +217,7 @@ impl ITXClient { } } - pub(crate) async fn rollback(&self) -> crate::Result<()> { + pub(crate) async fn rollback(&self) -> crate::Result { let msg = self.send_and_receive(TxOpRequestMsg::Rollback).await?; if let TxOpResponse::RolledBack(resp) = msg { diff --git a/query-engine/core/src/interactive_transactions/manager.rs b/query-engine/core/src/interactive_transactions/manager.rs index d9873c4383a7..654a9696e8f5 100644 --- a/query-engine/core/src/interactive_transactions/manager.rs +++ b/query-engine/core/src/interactive_transactions/manager.rs @@ -109,24 +109,30 @@ impl ItxManager { isolation_level: Option, timeout: Duration, ) -> crate::Result<()> { - // This task notifies the task spawned in `new()` method that the timeout for this - // transaction has expired. - crosstarget_utils::task::spawn({ - let timeout_sender = self.timeout_sender.clone(); - let tx_id = tx_id.clone(); - async move { - crosstarget_utils::time::sleep(timeout).await; - timeout_sender.send(tx_id).expect("receiver must exist"); - } - }); + // Only create a client if there is no client for this transaction yet. + // otherwise, begin a new transaction/savepoint for the existing client. + if self.transactions.read().await.contains_key(&tx_id) { + let _ = self.get_transaction(&tx_id, "begin").await?.lock().await.begin().await; + } else { + // This task notifies the task spawned in `new()` method that the timeout for this + // transaction has expired. + crosstarget_utils::task::spawn({ + let timeout_sender = self.timeout_sender.clone(); + let tx_id = tx_id.clone(); + async move { + crosstarget_utils::time::sleep(timeout).await; + timeout_sender.send(tx_id).expect("receiver must exist"); + } + }); - let transaction = - InteractiveTransaction::new(tx_id.clone(), conn, timeout, query_schema, isolation_level).await?; + let transaction = + InteractiveTransaction::new(tx_id.clone(), conn, timeout, query_schema, isolation_level).await?; - self.transactions - .write() - .await - .insert(tx_id, Arc::new(Mutex::new(transaction))); + self.transactions + .write() + .await + .insert(tx_id, Arc::new(Mutex::new(transaction))); + } Ok(()) } @@ -181,7 +187,7 @@ impl ItxManager { self.get_transaction(tx_id, "commit").await?.lock().await.commit().await } - pub async fn rollback_tx(&self, tx_id: &TxId) -> crate::Result<()> { + pub async fn rollback_tx(&self, tx_id: &TxId) -> crate::Result { self.get_transaction(tx_id, "rollback") .await? .lock() diff --git a/query-engine/core/src/interactive_transactions/mod.rs b/query-engine/core/src/interactive_transactions/mod.rs index 009cab37ccfd..57058af4d9e6 100644 --- a/query-engine/core/src/interactive_transactions/mod.rs +++ b/query-engine/core/src/interactive_transactions/mod.rs @@ -1,5 +1,5 @@ use derive_more::Display; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; mod error; mod manager; @@ -10,7 +10,7 @@ pub use error::*; pub(crate) use manager::*; pub(crate) use transaction::*; -#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize, Display)] +#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize, Serialize, Display)] #[display(fmt = "{}", _0)] pub struct TxId(String); diff --git a/query-engine/core/src/interactive_transactions/transaction.rs b/query-engine/core/src/interactive_transactions/transaction.rs index 4e84155ad78e..505ff1b622ee 100644 --- a/query-engine/core/src/interactive_transactions/transaction.rs +++ b/query-engine/core/src/interactive_transactions/transaction.rs @@ -197,28 +197,50 @@ impl InteractiveTransaction { }) } - pub async fn commit(&mut self) -> crate::Result<()> { - tx_timeout!(self, "commit", async { + pub async fn begin(&mut self) -> crate::Result<()> { + tx_timeout!(self, "begin", async { let name = self.name(); - let conn = self.state.as_open("commit")?; - let span = info_span!("prisma:engine:itx_commit", user_facing = true); + let conn = self.state.as_open("begin")?; + let span = info_span!("prisma:engine:itx_begin", user_facing = true); - if let Err(err) = conn.commit().instrument(span).await { - error!(?err, ?name, "transaction failed to commit"); - // We don't know if the transaction was committed or not. Because of that, we cannot - // leave it in "open" state. We attempt to rollback to get the transaction into a - // known state. + if let Err(err) = conn.begin().instrument(span).await { + error!(?err, ?name, "transaction failed to begin"); let _ = self.rollback(false).await; Err(err.into()) } else { - debug!(?name, "transaction committed"); - self.state = TransactionState::Committed; + debug!(?name, "transaction started"); Ok(()) } }) } - pub async fn rollback(&mut self, was_timeout: bool) -> crate::Result<()> { + pub async fn commit(&mut self) -> crate::Result<()> { + tx_timeout!(self, "commit", async { + let name = self.name(); + let conn = self.state.as_open("commit")?; + let span = info_span!("prisma:engine:itx_commit", user_facing = true); + + match conn.commit().instrument(span).await { + Ok(depth) => { + debug!(?depth, ?name, "transaction committed"); + if depth == 0 { + self.state = TransactionState::Committed; + } + Ok(()) + } + Err(err) => { + error!(?err, ?name, "transaction failed to commit"); + // We don't know if the transaction was committed or not. Because of that, we cannot + // leave it in "open" state. We attempt to rollback to get the transaction into a + // known state. + let _ = self.rollback(false).await; + Err(err.into()) + } + } + }) + } + + pub async fn rollback(&mut self, was_timeout: bool) -> crate::Result { let name = self.name(); let conn = self.state.as_open("rollback")?; let span = info_span!("prisma:engine:itx_rollback", user_facing = true); diff --git a/query-engine/driver-adapters/src/proxy.rs b/query-engine/driver-adapters/src/proxy.rs index cf78a4cbb88d..4c917797d5ad 100644 --- a/query-engine/driver-adapters/src/proxy.rs +++ b/query-engine/driver-adapters/src/proxy.rs @@ -49,6 +49,9 @@ pub(crate) struct TransactionProxy { /// transaction options options: TransactionOptions, + /// begin transaction + pub begin: AdapterMethod<(), ()>, + /// commit transaction commit: AdapterMethod<(), ()>, @@ -133,11 +136,13 @@ impl TransactionContextProxy { impl TransactionProxy { pub fn new(js_transaction: &JsObject) -> JsResult { let commit = get_named_property(js_transaction, "commit")?; + let begin = get_named_property(js_transaction, "begin")?; let rollback = get_named_property(js_transaction, "rollback")?; let options = get_named_property(js_transaction, "options")?; let options = from_js_value::(options); Ok(Self { + begin, commit, rollback, options, @@ -149,6 +154,10 @@ impl TransactionProxy { &self.options } + pub fn begin(&self) -> UnsafeFuture> + '_> { + UnsafeFuture(self.begin.call_as_async(())) + } + /// Commits the transaction via the driver adapter. /// /// ## Cancellation safety diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index 8aa7579762f6..d159a2f0de6c 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -328,16 +328,18 @@ impl JsQueryable { } // 3. Spawn a transaction from the context. - let tx = tx_ctx.start_transaction().await?; + let mut tx = tx_ctx.start_transaction().await?; - let begin_stmt = tx.begin_statement(); + tx.depth += 1; + + let begin_stmt = tx.begin_statement(tx.depth); let tx_opts = tx.options(); if tx_opts.use_phantom_query { - let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); + let begin_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); tx.raw_phantom_cmd(begin_stmt.as_str()).await?; } else { - tx.raw_cmd(begin_stmt).await?; + tx.raw_cmd(&begin_stmt).await?; } // 4. Set the isolation level (if specified) if we didn't do it before. diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs index 3a1167159ae5..dfc3a920c90a 100644 --- a/query-engine/driver-adapters/src/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -1,4 +1,4 @@ -use std::future::Future; +use std::{borrow::Cow, future::Future}; use async_trait::async_trait; use prisma_metrics::gauge; @@ -86,11 +86,16 @@ impl Queryable for JsTransactionContext { pub(crate) struct JsTransaction { tx_proxy: TransactionProxy, inner: JsBaseQueryable, + pub depth: u32, } impl JsTransaction { pub(crate) fn new(inner: JsBaseQueryable, tx_proxy: TransactionProxy) -> Self { - Self { inner, tx_proxy } + Self { + inner, + tx_proxy, + depth: 0, + } } pub fn options(&self) -> &TransactionOptions { @@ -112,36 +117,60 @@ impl JsTransaction { #[async_trait] impl QuaintTransaction for JsTransaction { - async fn commit(&self) -> quaint::Result<()> { + async fn begin(&mut self) -> quaint::Result<()> { // increment of this gauge is done in DriverProxy::startTransaction gauge!("prisma_client_queries_active").decrement(1.0); - let commit_stmt = "COMMIT"; + self.depth += 1; + + let begin_stmt = self.begin_statement(self.depth); if self.options().use_phantom_query { - let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); + let commit_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); self.raw_phantom_cmd(commit_stmt.as_str()).await?; } else { - self.inner.raw_cmd(commit_stmt).await?; + self.inner.raw_cmd(&begin_stmt).await?; } - UnsafeFuture(self.tx_proxy.commit()).await + UnsafeFuture(self.tx_proxy.begin()).await } - - async fn rollback(&self) -> quaint::Result<()> { + async fn commit(&mut self) -> quaint::Result { // increment of this gauge is done in DriverProxy::startTransaction gauge!("prisma_client_queries_active").decrement(1.0); - let rollback_stmt = "ROLLBACK"; + let commit_stmt = self.commit_statement(self.depth); if self.options().use_phantom_query { - let rollback_stmt = JsBaseQueryable::phantom_query_message(rollback_stmt); + let commit_stmt = JsBaseQueryable::phantom_query_message(&commit_stmt); + self.raw_phantom_cmd(commit_stmt.as_str()).await?; + } else { + self.inner.raw_cmd(&commit_stmt).await?; + } + + let _ = UnsafeFuture(self.tx_proxy.commit()).await; + + // Modify the depth value + self.depth -= 1; + + Ok(self.depth) + } + + async fn rollback(&mut self) -> quaint::Result { + let rollback_stmt = self.rollback_statement(self.depth); + + if self.options().use_phantom_query { + let rollback_stmt = JsBaseQueryable::phantom_query_message(&rollback_stmt); self.raw_phantom_cmd(rollback_stmt.as_str()).await?; } else { - self.inner.raw_cmd(rollback_stmt).await?; + self.inner.raw_cmd(&rollback_stmt).await?; } - UnsafeFuture(self.tx_proxy.rollback()).await + let _ = UnsafeFuture(self.tx_proxy.rollback()).await; + + // Modify the depth value + self.depth -= 1; + + Ok(self.depth) } fn as_queryable(&self) -> &dyn Queryable { @@ -198,6 +227,18 @@ impl Queryable for JsTransaction { fn requires_isolation_first(&self) -> bool { self.inner.requires_isolation_first() } + + fn begin_statement(&self, depth: u32) -> Cow<'static, str> { + self.inner.begin_statement(depth) + } + + fn commit_statement(&self, depth: u32) -> Cow<'static, str> { + self.inner.commit_statement(depth) + } + + fn rollback_statement(&self, depth: u32) -> Cow<'static, str> { + self.inner.rollback_statement(depth) + } } #[cfg(target_arch = "wasm32")] diff --git a/query-engine/query-engine/src/server/mod.rs b/query-engine/query-engine/src/server/mod.rs index 699d4d78d11b..3a8c2c4338f9 100644 --- a/query-engine/query-engine/src/server/mod.rs +++ b/query-engine/query-engine/src/server/mod.rs @@ -256,7 +256,13 @@ async fn transaction_start_handler(cx: Arc, req: Request) - let full_body = hyper::body::to_bytes(body_start).await?; let tx_opts = match serde_json::from_slice::(full_body.as_ref()) { - Ok(opts) => opts.with_new_transaction_id(), + Ok(opts) => { + if opts.new_tx_id.is_none() { + opts.with_new_transaction_id() + } else { + opts + } + } Err(_) => { return Ok(Response::builder() .status(StatusCode::BAD_REQUEST)