diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 776238952dd7..7001df7f72af 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -2721,30 +2721,48 @@ pub struct Unnest {
 
 #[cfg(test)]
 mod tests {
-
     use super::*;
     use crate::builder::LogicalTableSource;
     use crate::logical_plan::table_scan;
-    use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
+    use crate::{
+        col, count, exists, in_subquery, lit, max, placeholder, sum, GroupingSet,
+    };
 
+    use std::sync::OnceLock;
     use datafusion_common::tree_node::TreeNodeVisitor;
     use datafusion_common::{not_impl_err, Constraint, ScalarValue};
 
-    fn employee_schema() -> Schema {
-        Schema::new(vec![
-            Field::new("id", DataType::Int32, false),
-            Field::new("first_name", DataType::Utf8, false),
-            Field::new("last_name", DataType::Utf8, false),
-            Field::new("state", DataType::Utf8, false),
-            Field::new("salary", DataType::Int32, false),
-        ])
+    static EMPLOYEE_SCHEMA: OnceLock<Schema> = OnceLock::new();
+    fn employee_schema() -> &'static Schema {
+        EMPLOYEE_SCHEMA.get_or_init(|| {
+            Schema::new(vec![
+                Field::new("id", DataType::Int32, false),
+                Field::new("first_name", DataType::Utf8, false),
+                Field::new("last_name", DataType::Utf8, false),
+                Field::new("state", DataType::Utf8, false),
+                Field::new("salary", DataType::Int32, false),
+            ])
+        })
+    }
+
+    static ID_SCHEMA: OnceLock<Schema> = OnceLock::new();
+    fn id_schema() -> &'static Schema {
+        ID_SCHEMA
+            .get_or_init(|| Schema::new(vec![Field::new("id", DataType::Int32, false)]))
+    }
+
+    static FIRST_NAME_SCHEMA: OnceLock<Schema> = OnceLock::new();
+    fn first_name_schema() -> &'static Schema {
+        FIRST_NAME_SCHEMA.get_or_init(|| {
+            Schema::new(vec![Field::new("first_name", DataType::Utf8, false)])
+        })
     }
 
     fn display_plan() -> Result<LogicalPlan> {
-        let plan1 = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3]))?
+        let plan1 = table_scan(Some("employee_csv"), employee_schema(), Some(vec![3]))?
             .build()?;
 
-        table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))?
+        table_scan(Some("employee_csv"), employee_schema(), Some(vec![0, 3]))?
             .filter(in_subquery(col("state"), Arc::new(plan1)))?
             .project(vec![col("id")])?
             .build()
@@ -2780,14 +2798,13 @@ mod tests {
 
     #[test]
     fn test_display_subquery_alias() -> Result<()> {
-        let plan1 = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3]))?
+        let plan1 = table_scan(Some("employee_csv"), employee_schema(), Some(vec![3]))?
             .build()?;
         let plan1 = Arc::new(plan1);
 
-        let plan =
-            table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))?
-                .project(vec![col("id"), exists(plan1).alias("exists")])?
-                .build();
+        let plan = table_scan(Some("employee_csv"), employee_schema(), Some(vec![0, 3]))?
+            .project(vec![col("id"), exists(plan1).alias("exists")])?
+            .build();
 
         let expected = "Projection: employee_csv.id, EXISTS (<subquery>) AS exists\
         \n  Subquery:\
@@ -3333,4 +3350,86 @@ digraph {
         let actual = format!("{}", plan.display_indent());
         assert_eq!(expected.to_string(), actual)
     }
+
+    #[test]
+    fn recompute_schema_projection() -> Result<()> {
+        // SELECT id FROM employee_csv
+        let plan = table_scan(Some("employee_csv"), employee_schema(), None)?
+            .project(vec![col("id")])?
+            .build()?;
+        assert_eq!(plan.schema().as_arrow(), id_schema());
+
+        // rewrite to SELECT first_name FROM employee_csv
+        let plan = plan
+            .map_expressions(|_| Ok(Transformed::yes(col("first_name"))))?
+            .data;
+
+        // before recompute_schema, the schema is still the same
+        assert_eq!(plan.schema().as_arrow(), id_schema());
+        let plan = plan.recompute_schema()?;
+        assert_eq!(plan.schema().as_arrow(), first_name_schema());
+
+        Ok(())
+    }
+
+    #[test]
+    fn recompute_schema_window() -> Result<()> {
+        // SELECT id, SUM(salary) OVER () FROM employee_csv
+        let plan = table_scan(Some("employee_csv"), employee_schema(), None)?
+            .project(vec![col("id"), col("salary")])?
+            .window(vec![sum(col("salary"))])?
+            .build()?;
+
+        // rewrite to SELECT id, MAX(salary) OVER () FROM employee_csv
+        let plan = plan
+            .map_expressions(|_| Ok(Transformed::yes(max(col("salary")))))?
+            .data;
+
+        // before recompute_schema, the schema should be SUM
+        let expected_schema = Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("salary", DataType::Int32, false),
+            Field::new("SUM(employee_csv.salary)", DataType::Int64, true),
+        ]);
+        assert_eq!(plan.schema().as_arrow(), &expected_schema);
+
+        // after recompute_schema, the schema should be MAX
+        let plan = plan.recompute_schema()?;
+        let expected_schema = Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("salary", DataType::Int32, false),
+            Field::new("MAX(salary)", DataType::Int32, true),
+        ]);
+        assert_eq!(plan.schema().as_arrow(), &expected_schema);
+        Ok(())
+    }
+
+    #[test]
+    fn recompute_schema_aggregate() -> Result<()> {
+        // SELECT sum(salary) from employee_csv
+        let plan = table_scan(Some("employee_csv"), employee_schema(), None)?
+            .project(vec![col("salary")])?
+            .aggregate(vec![] as Vec<Expr>, vec![sum(col("salary"))])?
+            .build()?;
+
+        // rewrite to MAX(salary) FROM employee_csv
+        let plan = plan
+            .map_expressions(|_| Ok(Transformed::yes(max(col("salary")))))?
+            .data;
+
+        // before recompute_schema, the schema should be SUM
+        let expected_schema = Schema::new(vec![Field::new(
+            "SUM(employee_csv.salary)",
+            DataType::Int64,
+            true,
+        )]);
+        assert_eq!(plan.schema().as_arrow(), &expected_schema);
+
+        // after recompute_schema, the schema should be MAX
+        let plan = plan.recompute_schema()?;
+        let expected_schema =
+            Schema::new(vec![Field::new("MAX(salary)", DataType::Int32, true)]);
+        assert_eq!(plan.schema().as_arrow(), &expected_schema);
+        Ok(())
+    }
 }