Skip to content

Commit

Permalink
test: interval analysis unit tests (#14189)
Browse files Browse the repository at this point in the history
Added unit tests to interval analysis method which converts Expr tree
to a set of Intervals for columns in a given schema.
  • Loading branch information
hiltontj authored Jan 22, 2025
1 parent 361727a commit 5edb276
Showing 1 changed file with 116 additions and 0 deletions.
116 changes: 116 additions & 0 deletions datafusion/physical-expr/src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,119 @@ fn calculate_selectivity(
acc * cardinality_ratio(&initial.interval, &target.interval)
})
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema};
use datafusion_common::{assert_contains, DFSchema};
use datafusion_expr::{
col, execution_props::ExecutionProps, interval_arithmetic::Interval, lit, Expr,
};

use crate::{create_physical_expr, AnalysisContext};

use super::{analyze, ExprBoundaries};

fn make_field(name: &str, data_type: DataType) -> Field {
let nullable = false;
Field::new(name, data_type, nullable)
}

#[test]
fn test_analyze_boundary_exprs() {
let schema = Arc::new(Schema::new(vec![make_field("a", DataType::Int32)]));

/// Test case containing (expression tree, lower bound, upper bound)
type TestCase = (Expr, Option<i32>, Option<i32>);

let test_cases: Vec<TestCase> = vec![
// a > 10
(col("a").gt(lit(10)), Some(11), None),
// a < 20
(col("a").lt(lit(20)), None, Some(19)),
// a > 10 AND a < 20
(
col("a").gt(lit(10)).and(col("a").lt(lit(20))),
Some(11),
Some(19),
),
// a >= 10
(col("a").gt_eq(lit(10)), Some(10), None),
// a <= 20
(col("a").lt_eq(lit(20)), None, Some(20)),
// a >= 10 AND a <= 20
(
col("a").gt_eq(lit(10)).and(col("a").lt_eq(lit(20))),
Some(10),
Some(20),
),
// a > 10 AND a < 20 AND a < 15
(
col("a")
.gt(lit(10))
.and(col("a").lt(lit(20)))
.and(col("a").lt(lit(15))),
Some(11),
Some(14),
),
// (a > 10 AND a < 20) AND (a > 15 AND a < 25)
(
col("a")
.gt(lit(10))
.and(col("a").lt(lit(20)))
.and(col("a").gt(lit(15)))
.and(col("a").lt(lit(25))),
Some(16),
Some(19),
),
// (a > 10 AND a < 20) AND (a > 20 AND a < 30)
(
col("a")
.gt(lit(10))
.and(col("a").lt(lit(20)))
.and(col("a").gt(lit(20)))
.and(col("a").lt(lit(30))),
None,
None,
),
];
for (expr, lower, upper) in test_cases {
let boundaries = ExprBoundaries::try_new_unbounded(&schema).unwrap();
let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap();
let physical_expr =
create_physical_expr(&expr, &df_schema, &ExecutionProps::new()).unwrap();
let analysis_result = analyze(
&physical_expr,
AnalysisContext::new(boundaries),
df_schema.as_ref(),
)
.unwrap();
let actual = &analysis_result.boundaries[0].interval;
let expected = Interval::make(lower, upper).unwrap();
assert_eq!(
&expected, actual,
"did not get correct interval for SQL expression: {expr:?}"
);
}
}

#[test]
fn test_analyze_invalid_boundary_exprs() {
let schema = Arc::new(Schema::new(vec![make_field("a", DataType::Int32)]));
let expr = col("a").lt(lit(10)).or(col("a").gt(lit(20)));
let expected_error = "Interval arithmetic does not support the operator OR";
let boundaries = ExprBoundaries::try_new_unbounded(&schema).unwrap();
let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap();
let physical_expr =
create_physical_expr(&expr, &df_schema, &ExecutionProps::new()).unwrap();
let analysis_error = analyze(
&physical_expr,
AnalysisContext::new(boundaries),
df_schema.as_ref(),
)
.unwrap_err();
assert_contains!(analysis_error.to_string(), expected_error);
}
}

0 comments on commit 5edb276

Please sign in to comment.