Skip to content

Commit

Permalink
fix: handle scalar predicates in CASE expressions to prevent internal…
Browse files Browse the repository at this point in the history
… errors for InfallibleExprOrNull eval method (#14156)

* fix: handle scalar predicates in CASE expressions to prevent internal errors for InfallibleExprOrNull eval method

* Update to latest datafusion-testing commit

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
avkirilishin and alamb authored Jan 18, 2025
1 parent 5d18648 commit 868fc35
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 2 deletions.
2 changes: 1 addition & 1 deletion datafusion-testing
Submodule datafusion-testing updated 44 files
+4 −12 data/sqlite/random/expr/slt_good_102.slt
+2 −6 data/sqlite/random/expr/slt_good_104.slt
+2 −6 data/sqlite/random/expr/slt_good_105.slt
+2 −6 data/sqlite/random/expr/slt_good_108.slt
+2 −6 data/sqlite/random/expr/slt_good_111.slt
+2 −6 data/sqlite/random/expr/slt_good_113.slt
+2 −6 data/sqlite/random/expr/slt_good_114.slt
+5 −6 data/sqlite/random/expr/slt_good_118.slt
+2 −6 data/sqlite/random/expr/slt_good_12.slt
+2 −6 data/sqlite/random/expr/slt_good_14.slt
+2 −6 data/sqlite/random/expr/slt_good_15.slt
+4 −12 data/sqlite/random/expr/slt_good_21.slt
+2 −6 data/sqlite/random/expr/slt_good_22.slt
+4 −12 data/sqlite/random/expr/slt_good_24.slt
+11 −31 data/sqlite/random/expr/slt_good_28.slt
+4 −12 data/sqlite/random/expr/slt_good_30.slt
+2 −6 data/sqlite/random/expr/slt_good_36.slt
+3 −7 data/sqlite/random/expr/slt_good_38.slt
+5 −6 data/sqlite/random/expr/slt_good_39.slt
+2 −6 data/sqlite/random/expr/slt_good_4.slt
+10 −12 data/sqlite/random/expr/slt_good_41.slt
+2 −6 data/sqlite/random/expr/slt_good_42.slt
+2 −6 data/sqlite/random/expr/slt_good_45.slt
+2 −6 data/sqlite/random/expr/slt_good_46.slt
+2 −6 data/sqlite/random/expr/slt_good_48.slt
+2 −6 data/sqlite/random/expr/slt_good_50.slt
+3 −7 data/sqlite/random/expr/slt_good_52.slt
+2 −6 data/sqlite/random/expr/slt_good_53.slt
+4 −12 data/sqlite/random/expr/slt_good_64.slt
+5 −6 data/sqlite/random/expr/slt_good_66.slt
+2 −6 data/sqlite/random/expr/slt_good_68.slt
+6 −7 data/sqlite/random/expr/slt_good_7.slt
+7 −19 data/sqlite/random/expr/slt_good_72.slt
+6 −7 data/sqlite/random/expr/slt_good_73.slt
+2 −6 data/sqlite/random/expr/slt_good_78.slt
+2 −6 data/sqlite/random/expr/slt_good_80.slt
+8 −13 data/sqlite/random/expr/slt_good_81.slt
+2 −6 data/sqlite/random/expr/slt_good_82.slt
+5 −13 data/sqlite/random/expr/slt_good_85.slt
+2 −6 data/sqlite/random/expr/slt_good_88.slt
+2 −6 data/sqlite/random/expr/slt_good_94.slt
+2 −6 data/sqlite/random/expr/slt_good_95.slt
+1 −5 data/sqlite/random/groupby/slt_good_11.slt
+9 −21 data/sqlite/random/groupby/slt_good_12.slt
58 changes: 57 additions & 1 deletion datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,16 @@ impl CaseExpr {
fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let when_expr = &self.when_then_expr[0].0;
let then_expr = &self.when_then_expr[0].1;
if let ColumnarValue::Array(bit_mask) = when_expr.evaluate(batch)? {

let when_expr_value = when_expr.evaluate(batch)?;
let when_expr_value = match when_expr_value {
ColumnarValue::Scalar(_) => {
ColumnarValue::Array(when_expr_value.into_array(batch.num_rows())?)
}
other => other,
};

if let ColumnarValue::Array(bit_mask) = when_expr_value {
let bit_mask = bit_mask
.as_any()
.downcast_ref::<BooleanArray>()
Expand Down Expand Up @@ -896,6 +905,53 @@ mod tests {
Ok(())
}

#[test]
fn case_with_scalar_predicate() -> Result<()> {
let batch = case_test_batch_nulls()?;
let schema = batch.schema();

// SELECT CASE WHEN TRUE THEN load4 END
let when = lit(true);
let then = col("load4", &schema)?;
let expr = generate_case_when_with_type_coercion(
None,
vec![(when, then)],
None,
schema.as_ref(),
)?;

// many rows
let result = expr
.evaluate(&batch)?
.into_array(batch.num_rows())
.expect("Failed to convert to array");
let result =
as_float64_array(&result).expect("failed to downcast to Float64Array");
let expected = &Float64Array::from(vec![
Some(1.77),
None,
None,
Some(1.78),
None,
Some(1.77),
]);
assert_eq!(expected, result);

// one row
let expected = Float64Array::from(vec![Some(1.1)]);
let batch =
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
let result = expr
.evaluate(&batch)?
.into_array(batch.num_rows())
.expect("Failed to convert to array");
let result =
as_float64_array(&result).expect("failed to downcast to Float64Array");
assert_eq!(&expected, result);

Ok(())
}

#[test]
fn case_expr_matches_and_nulls() -> Result<()> {
let batch = case_test_batch_nulls()?;
Expand Down

0 comments on commit 868fc35

Please sign in to comment.