From e5611b409348037d50a8c80bed6cef07a5a5ed5b Mon Sep 17 00:00:00 2001 From: Matthew Gapp <61894094+matthewgapp@users.noreply.github.com> Date: Mon, 8 Jan 2024 16:41:18 -0800 Subject: [PATCH] fix issue where CTE could not be referenced more than 1 time --- datafusion/core/src/physical_planner.rs | 84 +++++++++++++++------- datafusion/sql/src/query.rs | 18 ++--- datafusion/sqllogictest/test_files/cte.slt | 80 +++++++++++++++++++++ 3 files changed, 147 insertions(+), 35 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index fdb757347e793..01c6e6f306643 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -19,7 +19,8 @@ use std::collections::HashMap; use std::fmt::Write; -use std::sync::Arc; +use std::sync::atomic::AtomicI32; +use std::sync::{Arc, OnceLock}; use crate::datasource::file_format::arrow::ArrowFormat; use crate::datasource::file_format::avro::AvroFormat; @@ -89,8 +90,8 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, - WindowFrameBound, WriteOp, NamedRelation, RecursiveQuery, + DescribeTable, DmlStatement, NamedRelation, RecursiveQuery, ScalarFunctionDefinition, + StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -452,11 +453,13 @@ impl PhysicalPlanner for DefaultPhysicalPlanner { logical_plan: &LogicalPlan, session_state: &SessionState, ) -> Result<Arc<dyn ExecutionPlan>> { + reset_recursive_cte_physical_plan_branch_number(); + match self.handle_explain(logical_plan, session_state).await?
{ Some(plan) => Ok(plan), None => { let plan = self - .create_initial_plan(logical_plan, session_state) + .create_initial_plan(logical_plan, session_state, None) .await?; self.optimize_internal(plan, session_state, |_, _| {}) } @@ -487,6 +490,23 @@ impl PhysicalPlanner for DefaultPhysicalPlanner { } } +// atomic global incrementer + +static RECURSIVE_CTE_PHYSICAL_PLAN_BRANCH: OnceLock<AtomicI32> = OnceLock::new(); + +fn new_recursive_cte_physical_plan_branch_number() -> u32 { + let counter = RECURSIVE_CTE_PHYSICAL_PLAN_BRANCH + .get_or_init(|| AtomicI32::new(0)) + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + counter as u32 +} + +fn reset_recursive_cte_physical_plan_branch_number() { + RECURSIVE_CTE_PHYSICAL_PLAN_BRANCH + .get_or_init(|| AtomicI32::new(0)) + .store(0, std::sync::atomic::Ordering::SeqCst); +} + impl DefaultPhysicalPlanner { /// Create a physical planner that uses `extension_planners` to /// plan user-defined logical nodes [`LogicalPlan::Extension`]. @@ -507,6 +527,7 @@ impl DefaultPhysicalPlanner { &'a self, logical_plans: impl IntoIterator<Item = &'a LogicalPlan> + Send + 'a, session_state: &'a SessionState, + ctx: Option<&'a String>, ) -> BoxFuture<'a, Result<Vec<Arc<dyn ExecutionPlan>>>> { async move { // First build futures with as little references as possible, then performing some stream magic.
@@ -519,7 +540,7 @@ impl DefaultPhysicalPlanner { .into_iter() .enumerate() .map(|(idx, lp)| async move { - let plan = self.create_initial_plan(lp, session_state).await?; + let plan = self.create_initial_plan(lp, session_state, ctx).await?; Ok((idx, plan)) as Result<_> }) .collect::<Vec<_>>(); @@ -548,6 +569,7 @@ impl DefaultPhysicalPlanner { &'a self, logical_plan: &'a LogicalPlan, session_state: &'a SessionState, + ctx: Option<&'a String>, ) -> BoxFuture<'a, Result<Arc<dyn ExecutionPlan>>> { async move { let exec_plan: Result<Arc<dyn ExecutionPlan>> = match logical_plan { @@ -572,7 +594,7 @@ impl DefaultPhysicalPlanner { single_file_output, copy_options, }) => { - let input_exec = self.create_initial_plan(input, session_state).await?; + let input_exec = self.create_initial_plan(input, session_state, ctx).await?; let parsed_url = ListingTableUrl::parse(output_url)?; let object_store_url = parsed_url.object_store(); @@ -620,7 +642,7 @@ impl DefaultPhysicalPlanner { let name = table_name.table(); let schema = session_state.schema_for_ref(table_name)?; if let Some(provider) = schema.table(name).await { - let input_exec = self.create_initial_plan(input, session_state).await?; + let input_exec = self.create_initial_plan(input, session_state, ctx).await?; provider.insert_into(session_state, input_exec, false).await } else { return exec_err!( @@ -637,7 +659,7 @@ impl DefaultPhysicalPlanner { let name = table_name.table(); let schema = session_state.schema_for_ref(table_name)?; if let Some(provider) = schema.table(name).await { - let input_exec = self.create_initial_plan(input, session_state).await?; + let input_exec = self.create_initial_plan(input, session_state, ctx).await?; provider.insert_into(session_state, input_exec, true).await } else { return exec_err!( @@ -678,7 +700,7 @@ impl DefaultPhysicalPlanner { ); } - let input_exec = self.create_initial_plan(input, session_state).await?; + let input_exec = self.create_initial_plan(input, session_state, ctx).await?; // at this moment we are guaranteed by the logical
planner // to have all the window_expr to have equal sort key @@ -774,7 +796,7 @@ impl DefaultPhysicalPlanner { .. }) => { // Initially need to perform the aggregate and then merge the partitions - let input_exec = self.create_initial_plan(input, session_state).await?; + let input_exec = self.create_initial_plan(input, session_state, ctx).await?; let physical_input_schema = input_exec.schema(); let logical_input_schema = input.as_ref().schema(); @@ -848,7 +870,7 @@ impl DefaultPhysicalPlanner { )?)) } LogicalPlan::Projection(Projection { input, expr, .. }) => { - let input_exec = self.create_initial_plan(input, session_state).await?; + let input_exec = self.create_initial_plan(input, session_state, ctx).await?; let input_schema = input.as_ref().schema(); let physical_exprs = expr @@ -900,7 +922,7 @@ impl DefaultPhysicalPlanner { )?)) } LogicalPlan::Filter(filter) => { - let physical_input = self.create_initial_plan(&filter.input, session_state).await?; + let physical_input = self.create_initial_plan(&filter.input, session_state, ctx).await?; let input_schema = physical_input.as_ref().schema(); let input_dfschema = filter.input.schema(); @@ -914,8 +936,8 @@ impl DefaultPhysicalPlanner { let filter = FilterExec::try_new(runtime_expr, physical_input)?; Ok(Arc::new(filter.with_default_selectivity(selectivity)?)) } - LogicalPlan::Union(Union { inputs, .. 
}) => { - let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state).await?; + LogicalPlan::Union(Union { inputs, schema }) => { + let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state, ctx).await?; Ok(Arc::new(UnionExec::new(physical_plans))) } @@ -923,7 +945,7 @@ impl DefaultPhysicalPlanner { input, partitioning_scheme, }) => { - let physical_input = self.create_initial_plan(input, session_state).await?; + let physical_input = self.create_initial_plan(input, session_state, ctx).await?; let input_schema = physical_input.schema(); let input_dfschema = input.as_ref().schema(); let physical_partitioning = match partitioning_scheme { @@ -954,7 +976,7 @@ impl DefaultPhysicalPlanner { )?)) } LogicalPlan::Sort(Sort { expr, input, fetch, .. }) => { - let physical_input = self.create_initial_plan(input, session_state).await?; + let physical_input = self.create_initial_plan(input, session_state, ctx).await?; let input_schema = physical_input.as_ref().schema(); let input_dfschema = input.as_ref().schema(); let sort_expr = expr @@ -1045,12 +1067,12 @@ impl DefaultPhysicalPlanner { }; return self - .create_initial_plan(&join_plan, session_state) + .create_initial_plan(&join_plan, session_state, ctx) .await; } // All equi-join keys are columns now, create physical join plan - let left_right = self.create_initial_plan_multi([left.as_ref(), right.as_ref()], session_state).await?; + let left_right = self.create_initial_plan_multi([left.as_ref(), right.as_ref()], session_state, ctx).await?; let [physical_left, physical_right]: [Arc<dyn ExecutionPlan>; 2] = left_right.try_into().map_err(|_| DataFusionError::Internal("`create_initial_plan_multi` is broken".to_string()))?; let left_df_schema = left.schema(); let right_df_schema = right.schema(); @@ -1185,7 +1207,7 @@ impl DefaultPhysicalPlanner { } } LogicalPlan::CrossJoin(CrossJoin { left, right, ..
}) => { - let left_right = self.create_initial_plan_multi([left.as_ref(), right.as_ref()], session_state).await?; + let left_right = self.create_initial_plan_multi([left.as_ref(), right.as_ref()], session_state, ctx).await?; let [left, right]: [Arc<dyn ExecutionPlan>; 2] = left_right.try_into().map_err(|_| DataFusionError::Internal("`create_initial_plan_multi` is broken".to_string()))?; Ok(Arc::new(CrossJoinExec::new(left, right))) } @@ -1203,10 +1225,10 @@ impl DefaultPhysicalPlanner { SchemaRef::new(schema.as_ref().to_owned().into()), ))), LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { - self.create_initial_plan(input, session_state).await + self.create_initial_plan(input, session_state, ctx).await } LogicalPlan::Limit(Limit { input, skip, fetch, .. }) => { - let input = self.create_initial_plan(input, session_state).await?; + let input = self.create_initial_plan(input, session_state, ctx).await?; // GlobalLimitExec requires a single partition for input let input = if input.output_partitioning().partition_count() == 1 { @@ -1224,7 +1246,7 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(GlobalLimitExec::new(input, *skip, *fetch))) } LogicalPlan::Unnest(Unnest { input, column, schema, options }) => { - let input = self.create_initial_plan(input, session_state).await?; + let input = self.create_initial_plan(input, session_state, ctx).await?; let column_exec = schema.index_of_column(column) .map(|idx| Column::new(&column.name, idx))?; let schema = SchemaRef::new(schema.as_ref().to_owned().into()); @@ -1277,7 +1299,7 @@ impl DefaultPhysicalPlanner { "Unsupported logical plan: Analyze must be root of the plan" ), LogicalPlan::Extension(e) => { - let physical_inputs = self.create_initial_plan_multi(e.node.inputs(), session_state).await?; + let physical_inputs = self.create_initial_plan_multi(e.node.inputs(), session_state, ctx).await?; let mut maybe_plan = None; for planner in &self.extension_planners { @@ -1313,13 +1335,19 @@ impl DefaultPhysicalPlanner { Ok(plan) } } + //
LogicalPlan::SubqueryAlias(SubqueryAlias()) LogicalPlan::RecursiveQuery(RecursiveQuery { name, static_term, recursive_term, is_distinct }) => { - let static_term = self.create_initial_plan(static_term, session_state).await?; - let recursive_term = self.create_initial_plan(recursive_term, session_state).await?; + let name = format!("{}-{}", name, new_recursive_cte_physical_plan_branch_number()); + + let ctx = Some(&name); + + let static_term = self.create_initial_plan(static_term, session_state, ctx).await?; + let recursive_term = self.create_initial_plan(recursive_term, session_state, ctx).await?; Ok(Arc::new(RecursiveQueryExec::new(name.clone(), static_term, recursive_term, *is_distinct))) } - LogicalPlan::NamedRelation(NamedRelation {name, schema}) => { + LogicalPlan::NamedRelation(NamedRelation {schema, ..}) => { + let name = ctx.expect("NamedRelation must have a context that contains the recursive query's branch name"); // Named relations is how we represent access to any sort of dynamic data provider. They // differ from tables in the sense that they can start existing dynamically during the // execution of a query and then disappear before it even finishes. 
@@ -1866,6 +1894,8 @@ impl DefaultPhysicalPlanner { logical_plan: &LogicalPlan, session_state: &SessionState, ) -> Result<Option<Arc<dyn ExecutionPlan>>> { + reset_recursive_cte_physical_plan_branch_number(); + if let LogicalPlan::Explain(e) = logical_plan { use PlanType::*; let mut stringified_plans = vec![]; @@ -1881,7 +1911,7 @@ impl DefaultPhysicalPlanner { if !config.logical_plan_only && e.logical_optimization_succeeded { match self - .create_initial_plan(e.plan.as_ref(), session_state) + .create_initial_plan(e.plan.as_ref(), session_state, None) .await { Ok(input) => { diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 413a9efca781b..8bef4ea119384 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -23,7 +23,8 @@ use datafusion_common::{ plan_err, sql_err, Constraints, DFSchema, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, + logical_plan, CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, + LogicalPlanBuilder, }; use sqlparser::ast::{ Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, SetOperator, @@ -133,10 +134,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { static_metadata, )?; + let name = cte_name.clone(); + // Step 2.2: Create a temporary relation logical plan that will be used // as the input to the recursive term let named_relation = LogicalPlanBuilder::named_relation( - cte_name.as_str(), + &name, Arc::new(named_relation_schema), ) .build()?; @@ -157,14 +160,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // ---------- Step 4: Create the final plan ------------------ // Step 4.1: Compile the final plan - let final_plan = LogicalPlanBuilder::from(static_plan) - .to_recursive_query( - cte_name.clone(), - recursive_plan, - distinct, - )? + let logical_plan = LogicalPlanBuilder::from(static_plan) + .to_recursive_query(name, recursive_plan, distinct)?
.build()?; + let final_plan = + self.apply_table_alias(logical_plan, cte.alias)?; + // Step 4.2: Remove the temporary relation from the planning context and replace it // with the final plan. planner_context.insert_cte(cte_name.clone(), final_plan); diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index 52ea127e1cf6f..06fac3f594d2f 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -131,3 +131,83 @@ WITH RECURSIVE nodes AS ( SELECT sum(id) FROM nodes ---- 55 + +# setup +statement ok +CREATE TABLE t(a BIGINT) AS VALUES(1),(2),(3); + +# referencing CTE multiple times does not error +query II rowsort +WITH RECURSIVE my_cte AS ( + SELECT a from t + UNION ALL + SELECT a+2 as a + FROM my_cte + WHERE a<5 +) +SELECT * FROM my_cte t1, my_cte +---- +1 1 +1 2 +1 3 +1 3 +1 4 +1 5 +1 5 +1 6 +2 1 +2 2 +2 3 +2 3 +2 4 +2 5 +2 5 +2 6 +3 1 +3 1 +3 2 +3 2 +3 3 +3 3 +3 3 +3 3 +3 4 +3 4 +3 5 +3 5 +3 5 +3 5 +3 6 +3 6 +4 1 +4 2 +4 3 +4 3 +4 4 +4 5 +4 5 +4 6 +5 1 +5 1 +5 2 +5 2 +5 3 +5 3 +5 3 +5 3 +5 4 +5 4 +5 5 +5 5 +5 5 +5 5 +5 6 +5 6 +6 1 +6 2 +6 3 +6 3 +6 4 +6 5 +6 5 +6 6