Add tests
alamb committed May 7, 2024
1 parent 826d51f commit 4687d6d
Showing 1 changed file with 116 additions and 17 deletions.
133 changes: 116 additions & 17 deletions datafusion/expr/src/logical_plan/plan.rs
@@ -2721,30 +2721,48 @@ pub struct Unnest {

#[cfg(test)]
mod tests {

use super::*;
use crate::builder::LogicalTableSource;
use crate::logical_plan::table_scan;
use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
use crate::{
col, count, exists, in_subquery, lit, max, placeholder, sum, GroupingSet,
};
use std::sync::OnceLock;

use datafusion_common::tree_node::TreeNodeVisitor;
use datafusion_common::{not_impl_err, Constraint, ScalarValue};

fn employee_schema() -> Schema {
Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("first_name", DataType::Utf8, false),
Field::new("last_name", DataType::Utf8, false),
Field::new("state", DataType::Utf8, false),
Field::new("salary", DataType::Int32, false),
])
static EMPLOYEE_SCHEMA: OnceLock<Schema> = OnceLock::new();
fn employee_schema() -> &'static Schema {
EMPLOYEE_SCHEMA.get_or_init(|| {
Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("first_name", DataType::Utf8, false),
Field::new("last_name", DataType::Utf8, false),
Field::new("state", DataType::Utf8, false),
Field::new("salary", DataType::Int32, false),
])
})
}

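    // Expected output schemas for the recompute_schema tests below; each is
    // cached in a OnceLock so the helper can hand out a &'static Schema.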
static ID_SCHEMA: OnceLock<Schema> = OnceLock::new();
fn id_schema() -> &'static Schema {
ID_SCHEMA
.get_or_init(|| Schema::new(vec![Field::new("id", DataType::Int32, false)]))
}

static FIRST_NAME_SCHEMA: OnceLock<Schema> = OnceLock::new();
fn first_name_schema() -> &'static Schema {
FIRST_NAME_SCHEMA.get_or_init(|| {
Schema::new(vec![Field::new("first_name", DataType::Utf8, false)])
})
}

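    // Helper that builds a projection over employee_csv filtered by an IN
    // subquery, used by the plan-display tests in this module.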
fn display_plan() -> Result<LogicalPlan> {
let plan1 = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3]))?
let plan1 = table_scan(Some("employee_csv"), employee_schema(), Some(vec![3]))?
.build()?;

table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))?
table_scan(Some("employee_csv"), employee_schema(), Some(vec![0, 3]))?
.filter(in_subquery(col("state"), Arc::new(plan1)))?
.project(vec![col("id")])?
.build()
@@ -2780,14 +2798,13 @@ mod tests {

#[test]
fn test_display_subquery_alias() -> Result<()> {
let plan1 = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3]))?
let plan1 = table_scan(Some("employee_csv"), employee_schema(), Some(vec![3]))?
.build()?;
let plan1 = Arc::new(plan1);

let plan =
table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))?
.project(vec![col("id"), exists(plan1).alias("exists")])?
.build();
let plan = table_scan(Some("employee_csv"), employee_schema(), Some(vec![0, 3]))?
.project(vec![col("id"), exists(plan1).alias("exists")])?
.build();

let expected = "Projection: employee_csv.id, EXISTS (<subquery>) AS exists\
\n Subquery:\
@@ -3333,4 +3350,86 @@ digraph {
let actual = format!("{}", plan.display_indent());
assert_eq!(expected.to_string(), actual)
}

#[test]
fn recompute_schema_projection() -> Result<()> {
// SELECT id FROM employee_csv
let plan = table_scan(Some("employee_csv"), employee_schema(), None)?
.project(vec![col("id")])?
.build()?;
assert_eq!(plan.schema().as_arrow(), id_schema());

// rewrite to SELECT first_name FROM employee_csv
let plan = plan
.map_expressions(|_| Ok(Transformed::yes(col("first_name"))))?
.data;

        // before recompute_schema the schema is unchanged, since map_expressions
        // rewrites the expression without rebuilding the plan's cached schema
assert_eq!(plan.schema().as_arrow(), id_schema());
let plan = plan.recompute_schema()?;
assert_eq!(plan.schema().as_arrow(), first_name_schema());

Ok(())
}

#[test]
fn recompute_schema_window() -> Result<()> {
// SELECT id, SUM(salary) OVER () FROM employee_csv
let plan = table_scan(Some("employee_csv"), employee_schema(), None)?
.project(vec![col("id"), col("salary")])?
.window(vec![sum(col("salary"))])?
.build()?;

// rewrite to SELECT id, MAX(salary) OVER () FROM employee_csv
let plan = plan
.map_expressions(|_| Ok(Transformed::yes(max(col("salary")))))?
.data;

        // before recompute_schema, the schema still reflects SUM
let expected_schema = Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("salary", DataType::Int32, false),
Field::new("SUM(employee_csv.salary)", DataType::Int64, true),
]);
assert_eq!(plan.schema().as_arrow(), &expected_schema);

// after recompute_schema, the schema should be MAX
let plan = plan.recompute_schema()?;
let expected_schema = Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("salary", DataType::Int32, false),
Field::new("MAX(salary)", DataType::Int32, true),
]);
assert_eq!(plan.schema().as_arrow(), &expected_schema);
Ok(())
}

#[test]
fn recompute_schema_aggregate() -> Result<()> {
        // SELECT SUM(salary) FROM employee_csv
let plan = table_scan(Some("employee_csv"), employee_schema(), None)?
.project(vec![col("salary")])?
.aggregate(vec![] as Vec<Expr>, vec![sum(col("salary"))])?
.build()?;

        // rewrite to SELECT MAX(salary) FROM employee_csv
let plan = plan
.map_expressions(|_| Ok(Transformed::yes(max(col("salary")))))?
.data;

        // before recompute_schema, the schema still reflects SUM
let expected_schema = Schema::new(vec![Field::new(
"SUM(employee_csv.salary)",
DataType::Int64,
true,
)]);
assert_eq!(plan.schema().as_arrow(), &expected_schema);

// after recompute_schema, the schema should be MAX
let plan = plan.recompute_schema()?;
let expected_schema =
Schema::new(vec![Field::new("MAX(salary)", DataType::Int32, true)]);
assert_eq!(plan.schema().as_arrow(), &expected_schema);
Ok(())
}
}
