diff --git a/Cargo.lock b/Cargo.lock index 61011e5..221f501 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -779,6 +779,12 @@ dependencies = [ "ahash", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.2" @@ -1923,6 +1929,7 @@ dependencies = [ "sqllogictest", "sqllogictest-engines", "tokio", + "tokio-util", "tracing", "tracing-subscriber", ] @@ -2212,6 +2219,8 @@ dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", + "hashbrown 0.14.5", "pin-project-lite", "tokio", ] diff --git a/sqllogictest-bin/Cargo.toml b/sqllogictest-bin/Cargo.toml index f8aa4b4..6587463 100644 --- a/sqllogictest-bin/Cargo.toml +++ b/sqllogictest-bin/Cargo.toml @@ -33,6 +33,7 @@ tokio = { version = "1", features = [ "fs", "process", ] } +tokio-util = { version = "0.7.12", features = ["rt"] } fs-err = "3.0.0" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing = "0.1" diff --git a/sqllogictest-bin/src/main.rs b/sqllogictest-bin/src/main.rs index 50ec5d1..2c6dd46 100644 --- a/sqllogictest-bin/src/main.rs +++ b/sqllogictest-bin/src/main.rs @@ -20,7 +20,7 @@ use sqllogictest::{ default_column_validator, default_normalizer, default_validator, update_record_with_output, AsyncDB, Injected, MakeConnection, Record, Runner, }; -use utils::AbortOnDropHandle; +use tokio_util::task::AbortOnDropHandle; #[derive(Default, Copy, Clone, Debug, PartialEq, Eq, ValueEnum)] #[must_use] @@ -323,7 +323,7 @@ async fn run_parallel( let engine = engine.clone(); let labels = labels.to_vec(); async move { - let (buf, res) = AbortOnDropHandle(tokio::spawn(async move { + let (buf, res) = AbortOnDropHandle::new(tokio::spawn(async move { let mut buf = vec![]; let res = connect_and_run_test_file(&mut buf, filename, &engine, config, &labels) @@ -835,28 +835,3 @@ async fn update_record( Ok(()) } - -mod utils { - use std::future::Future; - use std::pin::Pin; - use std::task::{Context, Poll}; - - use tokio::task::{JoinError, JoinHandle}; - - /// A wrapper around a [`tokio::task::JoinHandle`], which aborts the task when it is dropped. - pub struct AbortOnDropHandle(pub JoinHandle); - - impl Drop for AbortOnDropHandle { - fn drop(&mut self) { - self.0.abort(); - } - } - - impl Future for AbortOnDropHandle { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.0).poll(cx) - } - } -}