diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 252b00ca0adc..6f1c934c0855 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1208,6 +1208,7 @@ dependencies = [ "parking_lot", "rand", "tempfile", + "tokio", "url", ] @@ -1299,6 +1300,7 @@ dependencies = [ "pin-project-lite", "rand", "tokio", + "tokio-stream", "uuid", ] diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 85b97aac037d..a33973790c60 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -915,6 +915,11 @@ impl DFField { self.field = f.into(); self } + + pub fn with_qualifier(mut self, qualifier: impl Into) -> Self { + self.qualifier = Some(qualifier.into()); + self + } } impl From for DFField { diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d696c55a8c13..3ca16437bcb4 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -19,7 +19,8 @@ use std::collections::HashMap; use std::fmt::Write; -use std::sync::Arc; +use std::sync::atomic::AtomicI32; +use std::sync::{Arc, OnceLock}; use crate::datasource::file_format::arrow::ArrowFormat; use crate::datasource::file_format::avro::AvroFormat; @@ -47,6 +48,7 @@ use crate::physical_expr::create_physical_expr; use crate::physical_optimizer::optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; +use crate::physical_plan::continuance::ContinuanceExec; use crate::physical_plan::empty::EmptyExec; use crate::physical_plan::explain::ExplainExec; use crate::physical_plan::expressions::{Column, PhysicalSortExpr}; @@ -58,6 +60,7 @@ use crate::physical_plan::joins::{ use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::memory::MemoryExec; use crate::physical_plan::projection::ProjectionExec; +use crate::physical_plan::recursive_query::RecursiveQueryExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::union::UnionExec; @@ -87,8 +90,8 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, - WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, NamedRelation, RecursiveQuery, ScalarFunctionDefinition, + StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -450,11 +453,13 @@ impl PhysicalPlanner for DefaultPhysicalPlanner { logical_plan: &LogicalPlan, session_state: &SessionState, ) -> Result> { + reset_recursive_cte_physical_plan_branch_number(); + match self.handle_explain(logical_plan, session_state).await? { Some(plan) => Ok(plan), None => { let plan = self - .create_initial_plan(logical_plan, session_state) + .create_initial_plan(logical_plan, session_state, None) .await?; self.optimize_internal(plan, session_state, |_, _| {}) } @@ -485,6 +490,21 @@ impl PhysicalPlanner for DefaultPhysicalPlanner { } } +static RECURSIVE_CTE_PHYSICAL_PLAN_BRANCH: OnceLock = OnceLock::new(); + +fn new_recursive_cte_physical_plan_branch_number() -> u32 { + let counter = RECURSIVE_CTE_PHYSICAL_PLAN_BRANCH + .get_or_init(|| AtomicI32::new(0)) + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + counter as u32 +} + +fn reset_recursive_cte_physical_plan_branch_number() { + RECURSIVE_CTE_PHYSICAL_PLAN_BRANCH + .get_or_init(|| AtomicI32::new(0)) + .store(0, std::sync::atomic::Ordering::SeqCst); +} + impl DefaultPhysicalPlanner { /// Create a physical planner that uses `extension_planners` to /// plan user-defined logical nodes [`LogicalPlan::Extension`]. @@ -505,6 +525,7 @@ impl DefaultPhysicalPlanner { &'a self, logical_plans: impl IntoIterator + Send + 'a, session_state: &'a SessionState, + ctx: Option<&'a String>, ) -> BoxFuture<'a, Result>>> { async move { // First build futures with as little references as possible, then performing some stream magic. @@ -517,7 +538,7 @@ impl DefaultPhysicalPlanner { .into_iter() .enumerate() .map(|(idx, lp)| async move { - let plan = self.create_initial_plan(lp, session_state).await?; + let plan = self.create_initial_plan(lp, session_state, ctx).await?; Ok((idx, plan)) as Result<_> }) .collect::>(); @@ -546,6 +567,7 @@ impl DefaultPhysicalPlanner { &'a self, logical_plan: &'a LogicalPlan, session_state: &'a SessionState, + ctx: Option<&'a String>, ) -> BoxFuture<'a, Result>> { async move { let exec_plan: Result> = match logical_plan { @@ -570,7 +592,7 @@ impl DefaultPhysicalPlanner { single_file_output, copy_options, }) => { - let input_exec = self.create_initial_plan(input, session_state).await?; + let input_exec = self.create_initial_plan(input, session_state, ctx).await?; let parsed_url = ListingTableUrl::parse(output_url)?; let object_store_url = parsed_url.object_store(); @@ -618,7 +640,7 @@ impl DefaultPhysicalPlanner { let name = table_name.table(); let schema = session_state.schema_for_ref(table_name)?; if let Some(provider) = schema.table(name).await { - let input_exec = self.create_initial_plan(input, session_state).await?; + let input_exec = self.create_initial_plan(input, session_state, ctx).await?; provider.insert_into(session_state, input_exec, false).await } else { return exec_err!( @@ -635,7 +657,7 @@ impl DefaultPhysicalPlanner { let name = table_name.table(); let schema = session_state.schema_for_ref(table_name)?; if let Some(provider) = schema.table(name).await { - let input_exec = self.create_initial_plan(input, session_state).await?; + let input_exec = self.create_initial_plan(input, session_state, ctx).await?; provider.insert_into(session_state, input_exec, true).await } else { return exec_err!( @@ -676,7 +698,7 @@ impl DefaultPhysicalPlanner { ); } - let input_exec = self.create_initial_plan(input, session_state).await?; + let input_exec = self.create_initial_plan(input, session_state, ctx).await?; // at this moment we are guaranteed by the logical planner // to have all the window_expr to have equal sort key @@ -772,7 +794,7 @@ impl DefaultPhysicalPlanner { .. }) => { // Initially need to perform the aggregate and then merge the partitions - let input_exec = self.create_initial_plan(input, session_state).await?; + let input_exec = self.create_initial_plan(input, session_state, ctx).await?; let physical_input_schema = input_exec.schema(); let logical_input_schema = input.as_ref().schema(); @@ -846,7 +868,7 @@ impl DefaultPhysicalPlanner { )?)) } LogicalPlan::Projection(Projection { input, expr, .. }) => { - let input_exec = self.create_initial_plan(input, session_state).await?; + let input_exec = self.create_initial_plan(input, session_state, ctx).await?; let input_schema = input.as_ref().schema(); let physical_exprs = expr @@ -898,7 +920,7 @@ impl DefaultPhysicalPlanner { )?)) } LogicalPlan::Filter(filter) => { - let physical_input = self.create_initial_plan(&filter.input, session_state).await?; + let physical_input = self.create_initial_plan(&filter.input, session_state, ctx).await?; let input_schema = physical_input.as_ref().schema(); let input_dfschema = filter.input.schema(); @@ -912,8 +934,8 @@ impl DefaultPhysicalPlanner { let filter = FilterExec::try_new(runtime_expr, physical_input)?; Ok(Arc::new(filter.with_default_selectivity(selectivity)?)) } - LogicalPlan::Union(Union { inputs, .. }) => { - let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state).await?; + LogicalPlan::Union(Union { inputs, schema: _ }) => { + let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state, ctx).await?; Ok(Arc::new(UnionExec::new(physical_plans))) } @@ -921,7 +943,7 @@ impl DefaultPhysicalPlanner { input, partitioning_scheme, }) => { - let physical_input = self.create_initial_plan(input, session_state).await?; + let physical_input = self.create_initial_plan(input, session_state, ctx).await?; let input_schema = physical_input.schema(); let input_dfschema = input.as_ref().schema(); let physical_partitioning = match partitioning_scheme { @@ -952,7 +974,7 @@ impl DefaultPhysicalPlanner { )?)) } LogicalPlan::Sort(Sort { expr, input, fetch, .. }) => { - let physical_input = self.create_initial_plan(input, session_state).await?; + let physical_input = self.create_initial_plan(input, session_state, ctx).await?; let input_schema = physical_input.as_ref().schema(); let input_dfschema = input.as_ref().schema(); let sort_expr = expr @@ -1043,12 +1065,12 @@ impl DefaultPhysicalPlanner { }; return self - .create_initial_plan(&join_plan, session_state) + .create_initial_plan(&join_plan, session_state, ctx) .await; } // All equi-join keys are columns now, create physical join plan - let left_right = self.create_initial_plan_multi([left.as_ref(), right.as_ref()], session_state).await?; + let left_right = self.create_initial_plan_multi([left.as_ref(), right.as_ref()], session_state, ctx).await?; let [physical_left, physical_right]: [Arc; 2] = left_right.try_into().map_err(|_| DataFusionError::Internal("`create_initial_plan_multi` is broken".to_string()))?; let left_df_schema = left.schema(); let right_df_schema = right.schema(); @@ -1183,7 +1205,7 @@ impl DefaultPhysicalPlanner { } } LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - let left_right = self.create_initial_plan_multi([left.as_ref(), right.as_ref()], session_state).await?; + let left_right = self.create_initial_plan_multi([left.as_ref(), right.as_ref()], session_state, ctx).await?; let [left, right]: [Arc; 2] = left_right.try_into().map_err(|_| DataFusionError::Internal("`create_initial_plan_multi` is broken".to_string()))?; Ok(Arc::new(CrossJoinExec::new(left, right))) } @@ -1201,10 +1223,10 @@ impl DefaultPhysicalPlanner { SchemaRef::new(schema.as_ref().to_owned().into()), ))), LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { - self.create_initial_plan(input, session_state).await + self.create_initial_plan(input, session_state, ctx).await } LogicalPlan::Limit(Limit { input, skip, fetch, .. }) => { - let input = self.create_initial_plan(input, session_state).await?; + let input = self.create_initial_plan(input, session_state, ctx).await?; // GlobalLimitExec requires a single partition for input let input = if input.output_partitioning().partition_count() == 1 { @@ -1222,7 +1244,7 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(GlobalLimitExec::new(input, *skip, *fetch))) } LogicalPlan::Unnest(Unnest { input, column, schema, options }) => { - let input = self.create_initial_plan(input, session_state).await?; + let input = self.create_initial_plan(input, session_state, ctx).await?; let column_exec = schema.index_of_column(column) .map(|idx| Column::new(&column.name, idx))?; let schema = SchemaRef::new(schema.as_ref().to_owned().into()); @@ -1275,7 +1297,7 @@ impl DefaultPhysicalPlanner { "Unsupported logical plan: Analyze must be root of the plan" ), LogicalPlan::Extension(e) => { - let physical_inputs = self.create_initial_plan_multi(e.node.inputs(), session_state).await?; + let physical_inputs = self.create_initial_plan_multi(e.node.inputs(), session_state, ctx).await?; let mut maybe_plan = None; for planner in &self.extension_planners { @@ -1311,6 +1333,34 @@ impl DefaultPhysicalPlanner { Ok(plan) } } + LogicalPlan::RecursiveQuery(RecursiveQuery { name, static_term, recursive_term, is_distinct }) => { + let name = format!("{}-{}", name, new_recursive_cte_physical_plan_branch_number()); + + let ctx = Some(&name); + + let static_term = self.create_initial_plan(static_term, session_state, ctx).await?; + let recursive_term = self.create_initial_plan(recursive_term, session_state, ctx).await?; + + Ok(Arc::new(RecursiveQueryExec::new(name.clone(), static_term, recursive_term, *is_distinct))) + } + LogicalPlan::NamedRelation(NamedRelation {schema, ..}) => { + let name = ctx.expect("NamedRelation must have a context that contains the recursive query's branch name"); + // Named relations is how we represent access to any sort of dynamic data provider. They + // differ from tables in the sense that they can start existing dynamically during the + // execution of a query and then disappear before it even finishes. + // + // This system allows us to replicate the tricky behavior of classical databases where a + // temporary "working table" (as it is called in Postgres) can be used when dealing with + // complex operations (such as recursive CTEs) and then can be dropped. Since DataFusion + // at its core is heavily stream-based and vectorized, we try to avoid using 'real' tables + // and let the streams take care of the data flow in this as well. + + // Since the actual "input"'s will be only available to us at runtime (through task context) + // we can't really do any sort of meaningful validation here. + let schema = SchemaRef::new(schema.as_ref().to_owned().into()); + Ok(Arc::new(ContinuanceExec::new(name.clone(), schema))) + } + }; exec_plan }.boxed() @@ -1841,6 +1891,8 @@ impl DefaultPhysicalPlanner { logical_plan: &LogicalPlan, session_state: &SessionState, ) -> Result>> { + reset_recursive_cte_physical_plan_branch_number(); + if let LogicalPlan::Explain(e) = logical_plan { use PlanType::*; let mut stringified_plans = vec![]; @@ -1856,7 +1908,7 @@ impl DefaultPhysicalPlanner { if !config.logical_plan_only && e.logical_optimization_succeeded { match self - .create_initial_plan(e.plan.as_ref(), session_state) + .create_initial_plan(e.plan.as_ref(), session_state, None) .await { Ok(input) => { diff --git a/datafusion/execution/Cargo.toml b/datafusion/execution/Cargo.toml index e9bb87e9f8ac..374f731c5155 100644 --- a/datafusion/execution/Cargo.toml +++ b/datafusion/execution/Cargo.toml @@ -45,4 +45,12 @@ object_store = { workspace = true } parking_lot = { workspace = true } rand = { workspace = true } tempfile = { workspace = true } +tokio = { version = "1.28", features = [ + "macros", + "rt", + "rt-multi-thread", + "sync", + "fs", + "parking_lot", +] } url = { workspace = true } diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index 52c183b1612c..31a4df246946 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -33,6 +33,15 @@ use crate::{ runtime_env::{RuntimeConfig, RuntimeEnv}, }; +use arrow::record_batch::RecordBatch; +// use futures::channel::mpsc::Receiver as SingleChannelReceiver; +use tokio::sync::mpsc::Receiver as SingleChannelReceiver; +// use futures::lock::Mutex; +use parking_lot::Mutex; +// use futures:: + +type RelationHandler = SingleChannelReceiver>; + /// Task Execution Context /// /// A [`TaskContext`] contains the state available during a single @@ -56,6 +65,8 @@ pub struct TaskContext { window_functions: HashMap>, /// Runtime environment associated with this task context runtime: Arc, + /// Registered relation handlers + relation_handlers: Mutex>, } impl Default for TaskContext { @@ -72,6 +83,7 @@ impl Default for TaskContext { aggregate_functions: HashMap::new(), window_functions: HashMap::new(), runtime: Arc::new(runtime), + relation_handlers: Mutex::new(HashMap::new()), } } } @@ -99,6 +111,7 @@ impl TaskContext { aggregate_functions, window_functions, runtime, + relation_handlers: Mutex::new(HashMap::new()), } } @@ -171,6 +184,34 @@ impl TaskContext { self.runtime = runtime; self } + + /// Register a new relation handler. If a handler with the same name already exists + /// this function will return an error. + pub fn push_relation_handler( + &self, + name: String, + handler: RelationHandler, + ) -> Result<()> { + let mut handlers = self.relation_handlers.lock(); + if handlers.contains_key(&name) { + return Err(DataFusionError::Internal(format!( + "Relation handler {} already registered", + name + ))); + } + handlers.insert(name, handler); + Ok(()) + } + + /// Retrieve the relation handler for the given name. It will remove the handler from + /// the storage if it exists, and return it as is. + pub fn pop_relation_handler(&self, name: String) -> Result { + let mut handlers = self.relation_handlers.lock(); + + handlers.remove(name.as_str()).ok_or_else(|| { + DataFusionError::Internal(format!("Relation handler {} not registered", name)) + }) + } } impl FunctionRegistry for TaskContext { diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 847fbbbf61c7..d23a9f7b8178 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -55,6 +55,8 @@ use datafusion_common::{ ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; +use super::plan::{NamedRelation, RecursiveQuery}; + /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -121,6 +123,39 @@ impl LogicalPlanBuilder { })) } + /// A named temporary relation with a schema. + /// + /// This is used to represent a relation that does not exist at the + /// planning stage, but will be created at execution time with the + /// given schema. + pub fn named_relation(name: &str, schema: DFSchemaRef) -> Self { + Self::from(LogicalPlan::NamedRelation(NamedRelation { + name: name.to_string(), + schema, + })) + } + + /// Convert a regular plan into a recursive query. + pub fn to_recursive_query( + &self, + name: String, + recursive_term: LogicalPlan, + is_distinct: bool, + ) -> Result { + // TODO: we need to do a bunch of validation here. Maybe more. + if is_distinct { + return Err(DataFusionError::NotImplemented( + "Recursive queries with distinct is not supported".to_string(), + )); + } + Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { + name, + static_term: Arc::new(self.plan.clone()), + recursive_term: Arc::new(recursive_term), + is_distinct, + }))) + } + /// Create a values list based relation, and the schema is inferred from data, consuming /// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index bc722dd69ace..8ef0b522406d 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -35,9 +35,9 @@ pub use dml::{DmlStatement, WriteOp}; pub use plan::{ projection_schema, Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, - JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, - Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, - ToStringifiedPlan, Union, Unnest, Values, Window, + JoinType, Limit, LogicalPlan, NamedRelation, Partitioning, PlanType, Prepare, + Projection, RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery, + SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ SetVariable, Statement, TransactionAccessMode, TransactionConclusion, TransactionEnd, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 93a38fb40df5..665d2122401f 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -112,6 +112,8 @@ pub enum LogicalPlan { /// produces 0 or 1 row. This is used to implement SQL `SELECT` /// that has no values in the `FROM` clause. EmptyRelation(EmptyRelation), + /// A named temporary relation with a schema. + NamedRelation(NamedRelation), /// Produces the output of running another query. This is used to /// implement SQL subqueries Subquery(Subquery), @@ -154,6 +156,8 @@ pub enum LogicalPlan { /// Unnest a column that contains a nested list type such as an /// ARRAY. This is used to implement SQL `UNNEST` Unnest(Unnest), + /// A variadic query (e.g. "Recursive CTEs") + RecursiveQuery(RecursiveQuery), } impl LogicalPlan { @@ -191,6 +195,11 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(), LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, + LogicalPlan::NamedRelation(NamedRelation { schema, .. }) => schema, + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + // we take the schema of the static term as the schema of the entire recursive query + static_term.schema() + } } } @@ -233,6 +242,7 @@ impl LogicalPlan { LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) | LogicalPlan::EmptyRelation(_) + | LogicalPlan::NamedRelation(_) | LogicalPlan::Ddl(_) | LogicalPlan::Dml(_) | LogicalPlan::Copy(_) @@ -243,6 +253,10 @@ impl LogicalPlan { | LogicalPlan::TableScan(_) => { vec![self.schema()] } + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + // return only the schema of the static term + static_term.all_schemas() + } // return children schemas LogicalPlan::Limit(_) | LogicalPlan::Subquery(_) @@ -384,6 +398,9 @@ impl LogicalPlan { .try_for_each(f), // plans without expressions LogicalPlan::EmptyRelation(_) + | LogicalPlan::NamedRelation(_) + // TODO: not sure if this should go here + | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Limit(_) @@ -430,8 +447,14 @@ impl LogicalPlan { LogicalPlan::Ddl(ddl) => ddl.inputs(), LogicalPlan::Unnest(Unnest { input, .. }) => vec![input], LogicalPlan::Prepare(Prepare { input, .. }) => vec![input], + LogicalPlan::RecursiveQuery(RecursiveQuery { + static_term, + recursive_term, + .. + }) => vec![static_term, recursive_term], // plans without inputs LogicalPlan::TableScan { .. } + | LogicalPlan::NamedRelation(_) | LogicalPlan::Statement { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Values { .. } @@ -510,6 +533,9 @@ impl LogicalPlan { cross.left.head_output_expr() } } + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + static_term.head_output_expr() + } LogicalPlan::Union(union) => Ok(Some(Expr::Column( union.schema.fields()[0].qualified_column(), ))), @@ -529,6 +555,7 @@ impl LogicalPlan { } LogicalPlan::Subquery(_) => Ok(None), LogicalPlan::EmptyRelation(_) + | LogicalPlan::NamedRelation(_) | LogicalPlan::Prepare(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) @@ -835,6 +862,14 @@ impl LogicalPlan { }; Ok(LogicalPlan::Distinct(distinct)) } + LogicalPlan::RecursiveQuery(RecursiveQuery { + name, is_distinct, .. + }) => Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { + name: name.clone(), + static_term: Arc::new(inputs[0].clone()), + recursive_term: Arc::new(inputs[1].clone()), + is_distinct: *is_distinct, + })), LogicalPlan::Analyze(a) => { assert!(expr.is_empty()); assert_eq!(inputs.len(), 1); @@ -873,6 +908,7 @@ impl LogicalPlan { })) } LogicalPlan::EmptyRelation(_) + | LogicalPlan::NamedRelation(_) | LogicalPlan::Ddl(_) | LogicalPlan::Statement(_) => { // All of these plan types have no inputs / exprs so should not be called @@ -1073,6 +1109,9 @@ impl LogicalPlan { }), LogicalPlan::TableScan(TableScan { fetch, .. }) => *fetch, LogicalPlan::EmptyRelation(_) => Some(0), + // TODO: not sure if this is correct + LogicalPlan::NamedRelation(_) => None, + LogicalPlan::RecursiveQuery(_) => None, LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, @@ -1408,6 +1447,14 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self.0 { LogicalPlan::EmptyRelation(_) => write!(f, "EmptyRelation"), + LogicalPlan::NamedRelation(NamedRelation { name, .. }) => { + write!(f, "NamedRelation: {}", name) + } + LogicalPlan::RecursiveQuery(RecursiveQuery { + is_distinct, .. + }) => { + write!(f, "RecursiveQuery: is_distinct={}", is_distinct) + } LogicalPlan::Values(Values { ref values, .. }) => { let str_values: Vec<_> = values .iter() @@ -1718,6 +1765,28 @@ pub struct EmptyRelation { pub schema: DFSchemaRef, } +/// A named temporary relation with a known schema. +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct NamedRelation { + /// The relation name + pub name: String, + /// The schema description + pub schema: DFSchemaRef, +} + +/// A variadic query operation +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct RecursiveQuery { + /// Name of the query + pub name: String, + /// The static term + pub static_term: Arc, + /// The recursive term + pub recursive_term: Arc, + /// Distinction + pub is_distinct: bool, +} + /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 1e089257c61a..8113309f90f5 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -364,6 +364,8 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Dml(_) | LogicalPlan::Copy(_) | LogicalPlan::Unnest(_) + | LogicalPlan::NamedRelation(_) + | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Prepare(_) => { // apply the optimization to all inputs of the plan utils::optimize_children(self, plan, config)? diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 1d4eda0bd23e..10a16ea72251 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -162,6 +162,8 @@ fn optimize_projections( .collect::>() } LogicalPlan::EmptyRelation(_) + | LogicalPlan::NamedRelation(_) + | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) | LogicalPlan::Extension(_) diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 6c761fc9687c..e39f592efa4b 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -33,7 +33,9 @@ name = "datafusion_physical_plan" path = "src/lib.rs" [dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } @@ -55,6 +57,7 @@ parking_lot = { workspace = true } pin-project-lite = "^0.2.7" rand = { workspace = true } tokio = { version = "1.28", features = ["sync", "fs", "parking_lot"] } +tokio-stream = { version = "0.1.14" } uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] diff --git a/datafusion/physical-plan/src/continuance.rs b/datafusion/physical-plan/src/continuance.rs new file mode 100644 index 000000000000..b4fd3ba31985 --- /dev/null +++ b/datafusion/physical-plan/src/continuance.rs @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the continuance query plan + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::Partitioning; +use tokio_stream::wrappers::ReceiverStream; + +use crate::stream::RecordBatchStreamAdapter; +use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; + +use super::expressions::PhysicalSortExpr; + +use super::{ + metrics::{ExecutionPlanMetricsSet, MetricsSet}, + SendableRecordBatchStream, Statistics, +}; +use datafusion_common::{DataFusionError, Result}; + +/// A temporary "working table" operation where the input data will be +/// taken from the named handle during the execution and will be re-published +/// as is (kind of like a mirror). +/// +/// Most notably used in the implementation of recursive queries where the +/// underlying relation does not exist yet but the data will come as the previous +/// term is evaluated. This table will be used such that the recursive plan +/// will register a receiver in the task context and this plan will use that +/// receiver to get the data and stream it back up so that the batches are available +/// in the next iteration. +#[derive(Debug)] +pub struct ContinuanceExec { + /// Name of the relation handler + name: String, + /// The schema of the stream + schema: SchemaRef, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +impl ContinuanceExec { + /// Create a new execution plan for a continuance stream. The given relation + /// handler must exist in the task context before calling [`ContinuanceExec::execute`] on this + /// plan. + pub fn new(name: String, schema: SchemaRef) -> Self { + Self { + name, + schema, + metrics: ExecutionPlanMetricsSet::new(), + } + } +} + +impl DisplayAs for ContinuanceExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ContinuanceExec: name={}", self.name) + } + } + } +} + +impl ExecutionPlan for ContinuanceExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec> { + vec![] + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(1) + } + + fn maintains_input_order(&self) -> Vec { + vec![false] + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + Ok(Arc::new(ContinuanceExec::new( + self.name.clone(), + self.schema.clone(), + ))) + } + + /// This plan does not come with any special streams, but rather we use + /// the existing [`RecordBatchStreamAdapter`] to receive the data from + /// the registered handle. + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + // Continuance streams must be the plan base. + if partition != 0 { + return Err(DataFusionError::Internal(format!( + "ContinuanceExec got an invalid partition {} (expected 0)", + partition + ))); + } + + // The relation handler must be already registered by the + // parent op. + let receiver = context.pop_relation_handler(self.name.clone())?; + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema.clone(), + ReceiverStream::new(receiver), + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } +} + +#[cfg(test)] +mod tests {} diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 1dd1392b9d86..36d820a89ed0 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -49,6 +49,7 @@ pub mod analyze; pub mod coalesce_batches; pub mod coalesce_partitions; pub mod common; +pub mod continuance; pub mod display; pub mod empty; pub mod explain; @@ -61,6 +62,7 @@ pub mod metrics; mod ordering; pub mod placeholder_row; pub mod projection; +pub mod recursive_query; pub mod repartition; pub mod sorts; pub mod stream; diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs new file mode 100644 index 000000000000..39b025f6e1f2 --- /dev/null +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -0,0 +1,362 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the recursive query plan + +use std::any::Any; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::{DataFusionError, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::Partitioning; +use futures::{Stream, StreamExt}; +use tokio::sync::mpsc; + +use super::expressions::PhysicalSortExpr; +use super::metrics::BaselineMetrics; +use super::RecordBatchStream; +use super::{ + metrics::{ExecutionPlanMetricsSet, MetricsSet}, + SendableRecordBatchStream, Statistics, +}; +use arrow::error::ArrowError; +use tokio::sync::mpsc::{Receiver, Sender}; + +use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; + +/// Recursive query execution plan. +/// +/// This plan has two components: a base part (the static term) and +/// a dynamic part (the recursive term). The execution will start from +/// the base, and as long as the previous iteration produced at least +/// a single new row (taking care of the distinction) the recursive +/// part will be continuously executed. +/// +/// Before each execution of the dynamic part, the rows from the previous +/// iteration will be available in a "working table" (not a real table, +/// can be only accessed using a continuance operation). +/// +/// Note that there won't be any limit or checks applied to detect +/// an infinite recursion, so it is up to the planner to ensure that +/// it won't happen. +#[derive(Debug)] +pub struct RecursiveQueryExec { + /// Name of the query handler + name: String, + /// The base part (static term) + static_term: Arc, + /// The dynamic part (recursive term) + recursive_term: Arc, + /// Distinction + is_distinct: bool, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +impl RecursiveQueryExec { + /// Create a new RecursiveQueryExec + pub fn new( + name: String, + static_term: Arc, + recursive_term: Arc, + is_distinct: bool, + ) -> Self { + RecursiveQueryExec { + name, + static_term, + recursive_term, + is_distinct, + metrics: ExecutionPlanMetricsSet::new(), + } + } +} + +impl ExecutionPlan for RecursiveQueryExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.static_term.schema() + } + + fn children(&self) -> Vec> { + vec![self.static_term.clone(), self.recursive_term.clone()] + } + + // Distribution on a recursive query is really tricky to handle. + // For now, we are going to use a single partition but in the + // future we might find a better way to handle this. + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(1) + } + + // TODO: control these hints and see whether we can + // infer some from the child plans (static/recurisve terms). + fn maintains_input_order(&self) -> Vec { + vec![false, false] + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false, false] + } + + fn required_input_distribution(&self) -> Vec { + vec![ + datafusion_physical_expr::Distribution::SinglePartition, + datafusion_physical_expr::Distribution::SinglePartition, + ] + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(RecursiveQueryExec::new( + self.name.clone(), + children[0].clone(), + children[1].clone(), + self.is_distinct, + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + // TODO: we might be able to handle multiple partitions in the future. + if partition != 0 { + return Err(DataFusionError::Internal(format!( + "RecursiveQueryExec got an invalid partition {} (expected 0)", + partition + ))); + } + + let static_stream = self.static_term.execute(partition, context.clone())?; + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + Ok(Box::pin(RecursiveQueryStream::new( + context, + self.name.clone(), + self.recursive_term.clone(), + static_stream, + baseline_metrics, + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } +} + +impl DisplayAs for RecursiveQueryExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "RecursiveQueryExec: is_distinct={}", self.is_distinct) + } + } + } +} + +/// The actual logic of the recursive queries happens during the streaming +/// process. A simplified version of the algorithm is the following: +/// +/// buffer = [] +/// +/// while batch := static_stream.next(): +/// buffer.push(batch) +/// yield buffer +/// +/// while buffer.len() > 0: +/// sender, receiver = Channel() +/// register_continuation(handle_name, receiver) +/// sender.send(buffer.drain()) +/// recursive_stream = recursive_term.execute() +/// while batch := recursive_stream.next(): +/// buffer.append(batch) +/// yield buffer +/// +struct RecursiveQueryStream { + /// The context to be used for managing handlers & executing new tasks + task_context: Arc, + /// Name of the relation handler to be used by the recursive term + name: String, + /// The dynamic part (recursive term) as is (without being executed) + recursive_term: Arc, + /// The static part (static term) as a stream. If the processing of this + /// part is completed, then it will be None. + static_stream: Option, + /// The dynamic part (recursive term) as a stream. If the processing of this + /// part has not started yet, or has been completed, then it will be None. + recursive_stream: Option, + /// The schema of the output. + schema: SchemaRef, + /// In-memory buffer for storing a copy of the current results. Will be + /// cleared after each iteration. + buffer: Vec, + // /// Metrics. + _baseline_metrics: BaselineMetrics, +} + +impl RecursiveQueryStream { + /// Create a new recursive query stream + fn new( + task_context: Arc, + name: String, + recursive_term: Arc, + static_stream: SendableRecordBatchStream, + baseline_metrics: BaselineMetrics, + ) -> Self { + let schema = static_stream.schema(); + Self { + task_context, + name, + recursive_term, + static_stream: Some(static_stream), + recursive_stream: None, + schema, + buffer: vec![], + _baseline_metrics: baseline_metrics, + } + } + + /// Push a clone of the given batch to the in memory buffer, and then return + /// a poll with it. + fn push_batch( + mut self: std::pin::Pin<&mut Self>, + batch: RecordBatch, + ) -> Poll>> { + self.buffer.push(batch.clone()); + Poll::Ready(Some(Ok(batch))) + } + + /// Start polling for the next iteration, will be called either after the static term + /// is completed or another term is completed. It will follow the algorithm above on + /// to check whether the recursion has ended. + fn poll_next_iteration( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let total_length = self + .buffer + .iter() + .fold(0, |acc, batch| acc + batch.num_rows()); + + if total_length == 0 { + return Poll::Ready(None); + } + + // The initial capacity of the channels is the same as the number of partitions + // we currently hold in the buffer. + let (sender, receiver): ( + Sender>, + Receiver>, + ) = mpsc::channel(self.buffer.len() + 1); + + // There shouldn't be any handlers with this name, since the execution of recursive + // term will immediately consume the relation handler. + self.task_context + .push_relation_handler(self.name.clone(), receiver)?; + + // This part heavily assumes that the buffer is not going to change. Maybe we + // should use a mutex? + for batch in self.buffer.drain(..) { + match sender.try_send(Ok(batch.clone())) { + Ok(_) => {} + Err(e) => { + return Poll::Ready(Some(Err(DataFusionError::ArrowError( + ArrowError::from_external_error(Box::new(e)), + None, + )))); + } + } + } + + // We always execute (and re-execute iteratively) the first partition. + // Downstream plans should not expect any partitioning. + let partition = 0; + + self.recursive_stream = Some( + self.recursive_term + .execute(partition, self.task_context.clone())?, + ); + self.poll_next(cx) + } +} + +impl Stream for RecursiveQueryStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + // TODO: we should use this poll to record some metrics! + if let Some(static_stream) = &mut self.static_stream { + // While the static term's stream is available, we'll be forwarding the batches from it (also + // saving them for the initial iteration of the recursive term). + let poll = static_stream.poll_next_unpin(cx); + match &poll { + Poll::Ready(None) => { + // Once this is done, we can start running the setup for the recursive term. + self.static_stream = None; + self.poll_next_iteration(cx) + } + Poll::Ready(Some(Ok(batch))) => self.push_batch(batch.clone()), + _ => poll, + } + } else if let Some(recursive_stream) = &mut self.recursive_stream { + let poll = recursive_stream.poll_next_unpin(cx); + match &poll { + Poll::Ready(None) => { + self.recursive_stream = None; + self.poll_next_iteration(cx) + } + Poll::Ready(Some(Ok(batch))) => self.push_batch(batch.clone()), + _ => poll, + } + } else { + Poll::Ready(None) + } + } +} + +impl RecordBatchStream for RecursiveQueryStream { + /// Get the schema + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +#[cfg(test)] +mod tests {} diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index e8a38784481b..d45c1b32e82a 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1702,6 +1702,12 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::DescribeTable(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for DescribeTable", )), + LogicalPlan::NamedRelation(_) => Err(proto_error( + "LogicalPlan serde is not yet implemented for NamedRelation", + )), + LogicalPlan::RecursiveQuery(_) => Err(proto_error( + "LogicalPlan serde is not yet implemented for RecursiveQuery", + )), } } } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index dd4cab126261..0d0eb5b8753b 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -20,13 +20,14 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ - not_impl_err, plan_err, sql_err, Constraints, DataFusionError, Result, ScalarValue, + plan_err, sql_err, Constraints, DFSchema, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ - Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value, + Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, SetOperator, + SetQuantifier, Value, }; use sqlparser::parser::ParserError::ParserError; @@ -52,10 +53,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let set_expr = query.body; if let Some(with) = query.with { // Process CTEs from top to bottom - // do not allow self-references - if with.recursive { - return not_impl_err!("Recursive CTEs are not supported"); - } + let is_recursive = with.recursive; for cte in with.cte_tables { // A `WITH` block can't use the same name more than once @@ -65,16 +63,132 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { "WITH query name {cte_name:?} specified more than once" ))); } - // create logical plan & pass backreferencing CTEs - // CTE expr don't need extend outer_query_schema - let logical_plan = - self.query_to_plan(*cte.query, &mut planner_context.clone())?; + let cte_query = cte.query; + if is_recursive { + match *cte_query.body { + SetExpr::SetOperation { + op: SetOperator::Union, + left, + right, + set_quantifier, + } => { + let distinct = set_quantifier != SetQuantifier::All; + + // Each recursive CTE consists from two parts in the logical plan: + // 1. A static term (the left hand side on the SQL, where the + // referencing to the same CTE is not allowed) + // + // 2. A recursive term (the right hand side, and the recursive + // part) + + // Since static term does not have any specific properties, it can + // be compiled as if it was a regular expression. This will + // allow us to infer the schema to be used in the recursive term. + + // ---------- Step 1: Compile the static term ------------------ + let static_plan = self + .set_expr_to_plan(*left, &mut planner_context.clone())?; + + // Since the recursive CTEs include a component that references a + // table with its name, like the example below: + // + // WITH RECURSIVE values(n) AS ( + // SELECT 1 as n -- static term + // UNION ALL + // SELECT n + 1 + // FROM values -- self reference + // WHERE n < 100 + // ) + // + // We need a temporary 'relation' to be referenced and used. PostgreSQL + // calls this a 'working table', but it is entirely an implementation + // detail and a 'real' table with that name might not even exist (as + // in the case of DataFusion). + // + // Since we can't simply register a table during planning stage (it is + // an execution problem), we'll use a relation object that preserves the + // schema of the input perfectly and also knows which recursive CTE it is + // bound to. + + // ---------- Step 2: Create a temporary relation ------------------ + // Step 2.1: Create a schema for the temporary relation + let static_fields = static_plan.schema().fields().clone(); + let static_metadata = static_plan.schema().metadata().clone(); + + let named_relation_schema = DFSchema::new_with_metadata( + // take the fields from the static plan + // but add the cte_name as the qualifier + // so that we can access the fields in the recursive term using + // the cte_name as the qualifier (e.g. table.id) + static_fields + .into_iter() + .map(|field| { + if field.qualifier().is_some() { + field + } else { + field.with_qualifier(cte_name.clone()) + } + }) + .collect(), + static_metadata, + )?; + + let name = cte_name.clone(); - // Each `WITH` block can change the column names in the last - // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). - let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?; + // Step 2.2: Create a temporary relation logical plan that will be used + // as the input to the recursive term + let named_relation = LogicalPlanBuilder::named_relation( + &name, + Arc::new(named_relation_schema), + ) + .build()?; - planner_context.insert_cte(cte_name, logical_plan); + // Step 2.3: Register the temporary relation in the planning context + // For all the self references in the variadic term, we'll replace it + // with the temporary relation we created above by temporarily registering + // it as a CTE. This temporary relation in the planning context will be + // replaced by the actual CTE plan once we're done with the planning. + planner_context.insert_cte(cte_name.clone(), named_relation); + + // ---------- Step 3: Compile the recursive term ------------------ + // this uses the named_relation we inserted above to resolve the + // relation. This ensures that the recursive term uses the named relation logical plan + // and thus the 'continuance' physical plan as its input and source + let recursive_plan = self + .set_expr_to_plan(*right, &mut planner_context.clone())?; + + // ---------- Step 4: Create the final plan ------------------ + // Step 4.1: Compile the final plan + let logical_plan = LogicalPlanBuilder::from(static_plan) + .to_recursive_query(name, recursive_plan, distinct)? + .build()?; + + let final_plan = + self.apply_table_alias(logical_plan, cte.alias)?; + + // Step 4.2: Remove the temporary relation from the planning context and replace it + // with the final plan. + planner_context.insert_cte(cte_name.clone(), final_plan); + } + _ => { + return Err(DataFusionError::SQL( + ParserError("Invalid recursive CTE".to_string()), + None, + )); + } + }; + } else { + // create logical plan & pass backreferencing CTEs + // CTE expr don't need extend outer_query_schema + let logical_plan = + self.query_to_plan(*cte_query, &mut planner_context.clone())?; + + // Each `WITH` block can change the column names in the last + // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). + let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?; + + planner_context.insert_cte(cte_name, logical_plan); + } } } let plan = self.set_expr_to_plan(*(set_expr.clone()), planner_context)?; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 4de08a7124cf..2ec930d0c962 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -1387,22 +1387,6 @@ fn select_interval_out_of_range() { ); } -#[test] -fn recursive_ctes() { - let sql = " - WITH RECURSIVE numbers AS ( - select 1 as n - UNION ALL - select n + 1 FROM numbers WHERE N < 10 - ) - select * from numbers;"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "This feature is not implemented: Recursive CTEs are not supported", - err.strip_backtrace() - ); -} - #[test] fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() { quick_test( diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index c62b56584682..06fac3f594d2 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -19,3 +19,195 @@ query II select * from (WITH source AS (select 1 as e) SELECT * FROM source) t1, (WITH source AS (select 1 as e) SELECT * FROM source) t2 ---- 1 1 + +# trivial recursive CTE works +query I rowsort +WITH RECURSIVE nodes AS ( + SELECT 1 as id + UNION ALL + SELECT id + 1 as id + FROM nodes + WHERE id < 10 +) +SELECT * FROM nodes +---- +1 +10 +2 +3 +4 +5 +6 +7 +8 +9 + +# setup +statement ok +CREATE EXTERNAL TABLE beg_account_balance STORED as CSV WITH HEADER ROW LOCATION '../../testing/data/csv/recursive_query_account_beg_2.csv' + +# setup +statement ok +CREATE EXTERNAL TABLE account_balance_growth STORED as CSV WITH HEADER ROW LOCATION '../../testing/data/csv/recursive_query_account_growth_3.csv' + +# recursive CTE with static term derived from table works +query ITI rowsort +WITH RECURSIVE balances AS ( + SELECT * from beg_account_balance + UNION ALL + SELECT time + 1 as time, name, account_balance + 10 as account_balance + FROM balances + WHERE time < 10 +) +SELECT * FROM balances +---- +1 John 100 +1 Tim 200 +10 John 190 +10 Tim 290 +2 John 110 +2 Tim 210 +3 John 120 +3 Tim 220 +4 John 130 +4 Tim 230 +5 John 140 +5 Tim 240 +6 John 150 +6 Tim 250 +7 John 160 +7 Tim 260 +8 John 170 +8 Tim 270 +9 John 180 +9 Tim 280 + + +# recursive CTE with recursive join works +query ITI +WITH RECURSIVE balances AS ( + SELECT time as time, name as name, account_balance as account_balance + FROM beg_account_balance + UNION ALL + SELECT time + 1 as time, balances.name, account_balance + account_balance_growth.account_growth as account_balance + FROM balances + JOIN account_balance_growth + ON balances.name = account_balance_growth.name + WHERE time < 10 +) +SELECT * FROM balances +ORDER BY time, name +---- +1 John 100 +1 Tim 200 +2 John 103 +2 Tim 220 +3 John 106 +3 Tim 240 +4 John 109 +4 Tim 260 +5 John 112 +5 Tim 280 +6 John 115 +6 Tim 300 +7 John 118 +7 Tim 320 +8 John 121 +8 Tim 340 +9 John 124 +9 Tim 360 +10 John 127 +10 Tim 380 + +# recursive CTE with aggregations works +query I rowsort +WITH RECURSIVE nodes AS ( + SELECT 1 as id + UNION ALL + SELECT id + 1 as id + FROM nodes + WHERE id < 10 +) +SELECT sum(id) FROM nodes +---- +55 + +# setup +statement ok +CREATE TABLE t(a BIGINT) AS VALUES(1),(2),(3); + +# referencing CTE multiple times does not error +query II rowsort +WITH RECURSIVE my_cte AS ( + SELECT a from t + UNION ALL + SELECT a+2 as a + FROM my_cte + WHERE a<5 +) +SELECT * FROM my_cte t1, my_cte +---- +1 1 +1 2 +1 3 +1 3 +1 4 +1 5 +1 5 +1 6 +2 1 +2 2 +2 3 +2 3 +2 4 +2 5 +2 5 +2 6 +3 1 +3 1 +3 2 +3 2 +3 3 +3 3 +3 3 +3 3 +3 4 +3 4 +3 5 +3 5 +3 5 +3 5 +3 6 +3 6 +4 1 +4 2 +4 3 +4 3 +4 4 +4 5 +4 5 +4 6 +5 1 +5 1 +5 2 +5 2 +5 3 +5 3 +5 3 +5 3 +5 4 +5 4 +5 5 +5 5 +5 5 +5 5 +5 6 +5 6 +6 1 +6 2 +6 3 +6 3 +6 4 +6 5 +6 5 +6 6 diff --git a/testing b/testing index 98fceecd024d..bb8b92eb0ba7 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 98fceecd024dccd2f8a00e32fc144975f218acf4 +Subproject commit bb8b92eb0ba7d9d1ae2348f454d97dd361d36ade