-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Type Coercion for UDF Arguments #14268
base: main
Are you sure you want to change the base?
Changes from 12 commits
bcc0620
a8539c1
4e4cb02
eb61d49
a6f62a0
5afbfc0
36a23b7
d2eadea
944e0a3
2d77206
1a27626
e714ba1
93d75b1
5067223
4d395e2
ec8ccd1
d78877a
041e4ac
437b83d
8fd9fb3
62b97c5
46350c9
3f0c870
46cda71
e7d474d
1f826cb
5d258b2
17df1bc
97c7db0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,9 +19,12 @@ use crate::utils::make_scalar_function; | |
use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array}; | ||
use arrow::datatypes::DataType; | ||
use arrow::error::ArrowError; | ||
use datafusion_common::types::logical_string; | ||
use datafusion_common::{internal_err, Result}; | ||
use datafusion_expr::{ColumnarValue, Documentation}; | ||
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; | ||
use datafusion_expr::{ | ||
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, | ||
TypeSignatureClass, Volatility, | ||
}; | ||
use datafusion_macros::user_doc; | ||
use std::any::Any; | ||
use std::sync::Arc; | ||
|
@@ -61,7 +64,15 @@ impl Default for AsciiFunc { | |
impl AsciiFunc { | ||
pub fn new() -> Self { | ||
Self { | ||
signature: Signature::string(1, Volatility::Immutable), | ||
signature: Signature::one_of( | ||
vec![ | ||
TypeSignature::String(1), | ||
TypeSignature::Coercible(vec![TypeSignatureClass::Native( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's what I initially did, but after testing on Sail, I discovered new test failures related to coercing input that's all String (e.g. The plan is to port all the relevant tests from Sail into this PR! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jayzhan211 You can find the test failures here if interested! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Although it was not my intention to apply this pattern on single arg functions. I'll get that fixed! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current design for coercion may still have room for improvement. It would be beneficial to represent the function signature in a simpler and more concise manner, rather than relying on complex combinations of multiple, similar signatures. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed! I'll add that in. |
||
logical_string(), | ||
)]), | ||
], | ||
Volatility::Immutable, | ||
), | ||
} | ||
} | ||
} | ||
|
@@ -130,20 +141,23 @@ pub fn ascii(args: &[ArrayRef]) -> Result<ArrayRef> { | |
let string_array = args[0].as_string_view(); | ||
Ok(calculate_ascii(string_array)?) | ||
} | ||
_ => internal_err!("Unsupported data type"), | ||
other => internal_err!("Unsupported data type for ascii: {:?}", other), | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use crate::expr_fn::ascii; | ||
use crate::string::ascii::AsciiFunc; | ||
use crate::utils::test::test_function; | ||
use arrow::array::{Array, Int32Array}; | ||
use arrow::array::{Array, ArrayRef, Int32Array, RecordBatch, StringArray}; | ||
use arrow::datatypes::DataType::Int32; | ||
use datafusion_common::{Result, ScalarValue}; | ||
use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; | ||
use datafusion::prelude::SessionContext; | ||
use datafusion_common::{DFSchema, Result, ScalarValue}; | ||
use datafusion_expr::{col, lit, ColumnarValue, ScalarUDFImpl}; | ||
use std::sync::Arc; | ||
|
||
macro_rules! test_ascii { | ||
macro_rules! test_ascii_invoke { | ||
($INPUT:expr, $EXPECTED:expr) => { | ||
test_function!( | ||
AsciiFunc::new(), | ||
|
@@ -174,12 +188,64 @@ mod tests { | |
}; | ||
} | ||
|
||
fn ascii_array(input: ArrayRef) -> Result<ArrayRef> { | ||
let batch = RecordBatch::try_from_iter([("c0", input)])?; | ||
let df_schema = DFSchema::try_from(batch.schema())?; | ||
let expr = ascii(col("c0")); | ||
let physical_expr = | ||
SessionContext::new().create_physical_expr(expr, &df_schema)?; | ||
let result = match physical_expr.evaluate(&batch)? { | ||
ColumnarValue::Array(result) => Ok(result), | ||
_ => datafusion_common::internal_err!("ascii"), | ||
}?; | ||
Ok(result) | ||
} | ||
|
||
fn ascii_scalar(input: ScalarValue) -> Result<ScalarValue> { | ||
let df_schema = DFSchema::empty(); | ||
let expr = ascii(lit(input)); | ||
let physical_expr = | ||
SessionContext::new().create_physical_expr(expr, &df_schema)?; | ||
let result = match physical_expr | ||
.evaluate(&RecordBatch::new_empty(Arc::clone(df_schema.inner())))? | ||
{ | ||
ColumnarValue::Scalar(result) => Ok(result), | ||
_ => datafusion_common::internal_err!("ascii"), | ||
}?; | ||
Ok(result) | ||
} | ||
|
||
#[test] | ||
fn test_functions() -> Result<()> { | ||
test_ascii!(Some(String::from("x")), Ok(Some(120))); | ||
test_ascii!(Some(String::from("a")), Ok(Some(97))); | ||
test_ascii!(Some(String::from("")), Ok(Some(0))); | ||
test_ascii!(None, Ok(None)); | ||
fn test_ascii_invoke() -> Result<()> { | ||
test_ascii_invoke!(Some(String::from("x")), Ok(Some(120))); | ||
test_ascii_invoke!(Some(String::from("a")), Ok(Some(97))); | ||
test_ascii_invoke!(Some(String::from("")), Ok(Some(0))); | ||
test_ascii_invoke!(None, Ok(None)); | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_ascii_expr() -> Result<()> { | ||
let input = Arc::new(StringArray::from(vec![Some("x")])) as ArrayRef; | ||
let expected = Arc::new(Int32Array::from(vec![Some(120)])) as ArrayRef; | ||
let result = ascii_array(input)?; | ||
assert_eq!(&expected, &result); | ||
|
||
let input = ScalarValue::Utf8(Some(String::from("x"))); | ||
let expected = ScalarValue::Int32(Some(120)); | ||
let result = ascii_scalar(input)?; | ||
assert_eq!(&expected, &result); | ||
|
||
let input = Arc::new(Int32Array::from(vec![Some(2)])) as ArrayRef; | ||
let expected = Arc::new(Int32Array::from(vec![Some(50)])) as ArrayRef; | ||
let result = ascii_array(input)?; | ||
assert_eq!(&expected, &result); | ||
|
||
let input = ScalarValue::Int32(Some(2)); | ||
let expected = ScalarValue::Int32(Some(50)); | ||
let result = ascii_scalar(input)?; | ||
assert_eq!(&expected, &result); | ||
|
||
Ok(()) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does that mean others function that used Coercible String now also cast integer to string?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's
TypeSignature::Coercible
with aTypeSignatureClass::Native(logical_string())
, then yes. Any function that specifiesTypeSignature::Coercible
with aTypeSignatureClass::Native
should coerce according to the behavior implemented in thedefault_cast_for
function forNativeType
in order to be consistent.