From 868fc35ec7e0dac94a6fbd21a8a62ed09e1024dd Mon Sep 17 00:00:00 2001 From: Aleksey Kirilishin <54231417+avkirilishin@users.noreply.github.com> Date: Sat, 18 Jan 2025 20:02:25 +0300 Subject: [PATCH] fix: handle scalar predicates in CASE expressions to prevent internal errors for InfallibleExprOrNull eval method (#14156) * fix: handle scalar predicates in CASE expressions to prevent internal errors for InfallibleExprOrNull eval method * Update to latest datafusion-testing commit --------- Co-authored-by: Andrew Lamb --- datafusion-testing | 2 +- .../physical-expr/src/expressions/case.rs | 58 ++++++++++++++++++- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/datafusion-testing b/datafusion-testing index 36283d195c72..5b424aefd7f6 160000 --- a/datafusion-testing +++ b/datafusion-testing @@ -1 +1 @@ -Subproject commit 36283d195c728f26b16b517ba999fd62509b6649 +Subproject commit 5b424aefd7f6bf198220c37f59d39dbb25b47695 diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 16b97c203c30..be1043d09cc1 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -344,7 +344,16 @@ impl CaseExpr { fn case_column_or_null(&self, batch: &RecordBatch) -> Result { let when_expr = &self.when_then_expr[0].0; let then_expr = &self.when_then_expr[0].1; - if let ColumnarValue::Array(bit_mask) = when_expr.evaluate(batch)? { + + let when_expr_value = when_expr.evaluate(batch)?; + let when_expr_value = match when_expr_value { + ColumnarValue::Scalar(_) => { + ColumnarValue::Array(when_expr_value.into_array(batch.num_rows())?) + } + other => other, + }; + + if let ColumnarValue::Array(bit_mask) = when_expr_value { let bit_mask = bit_mask .as_any() .downcast_ref::() @@ -896,6 +905,53 @@ mod tests { Ok(()) } + #[test] + fn case_with_scalar_predicate() -> Result<()> { + let batch = case_test_batch_nulls()?; + let schema = batch.schema(); + + // SELECT CASE WHEN TRUE THEN load4 END + let when = lit(true); + let then = col("load4", &schema)?; + let expr = generate_case_when_with_type_coercion( + None, + vec![(when, then)], + None, + schema.as_ref(), + )?; + + // many rows + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + let expected = &Float64Array::from(vec![ + Some(1.77), + None, + None, + Some(1.78), + None, + Some(1.77), + ]); + assert_eq!(expected, result); + + // one row + let expected = Float64Array::from(vec![Some(1.1)]); + let batch = + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = + as_float64_array(&result).expect("failed to downcast to Float64Array"); + assert_eq!(&expected, result); + + Ok(()) + } + #[test] fn case_expr_matches_and_nulls() -> Result<()> { let batch = case_test_batch_nulls()?;