Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Type Coercion for UDF Arguments #14268

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
bcc0620
Fix DF 43 regression coerce ascii input to string
shehabgamin Jan 22, 2025
a8539c1
datafusion-testing submodule has new commits
shehabgamin Jan 22, 2025
4e4cb02
Merge branch 'main' of github.com:lakehq/datafusion into sail-df-43-r…
shehabgamin Jan 23, 2025
eb61d49
implicit cast int to string
shehabgamin Jan 24, 2025
a6f62a0
Merge branch 'main' of github.com:lakehq/datafusion into sail-df-43-r…
shehabgamin Jan 24, 2025
5afbfc0
fix can_coerce_to and add tests
shehabgamin Jan 24, 2025
36a23b7
update deprecation message for values exec
shehabgamin Jan 24, 2025
d2eadea
Merge branch 'main' of github.com:lakehq/datafusion into sail-df-43-r…
shehabgamin Jan 24, 2025
944e0a3
lint
shehabgamin Jan 24, 2025
2d77206
coerce to string
shehabgamin Jan 24, 2025
1a27626
Adjust type signature
shehabgamin Jan 24, 2025
e714ba1
fix comment
shehabgamin Jan 24, 2025
93d75b1
clean up clippy warnings
shehabgamin Jan 25, 2025
5067223
type signature coercible
shehabgamin Jan 25, 2025
4d395e2
bump pyo3
shehabgamin Jan 25, 2025
ec8ccd1
udf type coercion
shehabgamin Jan 25, 2025
d78877a
moving testing out of functions crate due to circular dependencies
shehabgamin Jan 25, 2025
041e4ac
Merge branch 'main' of github.com:lakehq/datafusion into sail-df-43-r…
shehabgamin Jan 26, 2025
437b83d
find the common string type for TypeSignature::Coercible
shehabgamin Jan 26, 2025
8fd9fb3
update coercible string
shehabgamin Jan 26, 2025
62b97c5
fix error msg and add tests for udfs
shehabgamin Jan 26, 2025
46350c9
update docs to note that args are coercible string
shehabgamin Jan 26, 2025
3f0c870
update expr test
shehabgamin Jan 26, 2025
46cda71
undo
shehabgamin Jan 27, 2025
e7d474d
Merge branch 'main' of github.com:lakehq/datafusion into sail-df-43-r…
shehabgamin Jan 27, 2025
1f826cb
remove test since already covered in slt
shehabgamin Jan 27, 2025
5d258b2
add dictionary to base_yupe
shehabgamin Jan 27, 2025
17df1bc
add dictionary to base_type
shehabgamin Jan 27, 2025
97c7db0
fix lint issues
shehabgamin Jan 28, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ impl Signature {
}
}

/// A specified number of numeric arguments
/// A specified number of string arguments
pub fn string(arg_count: usize, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::String(arg_count),
Expand Down
32 changes: 15 additions & 17 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,23 +584,21 @@ fn get_valid_types(
match target_type_class {
TypeSignatureClass::Native(native_type) => {
let target_type = native_type.native();
if &logical_type == target_type {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that mean others function that used Coercible String now also cast integer to string?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's TypeSignature::Coercible with a TypeSignatureClass::Native(logical_string()), then yes. Any function that specifies TypeSignature::Coercible with a TypeSignatureClass::Native should coerce according to the behavior implemented in the default_cast_for function for NativeType in order to be consistent.

return target_type.default_cast_for(current_type);
}

if logical_type == NativeType::Null {
return target_type.default_cast_for(current_type);
}

if target_type.is_integer() && logical_type.is_integer() {
return target_type.default_cast_for(current_type);
}

internal_err!(
"Expect {} but received {}",
target_type_class,
current_type
)
target_type.default_cast_for(current_type)
// if &logical_type == target_type {
// return target_type.default_cast_for(current_type);
// }
// if logical_type == NativeType::Null {
// return target_type.default_cast_for(current_type);
// }
// if target_type.is_integer() && logical_type.is_integer() {
// return target_type.default_cast_for(current_type);
// }
// internal_err!(
// "Expect {} but received {}",
// target_type_class,
// current_type
// )
}
// Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp
TypeSignatureClass::Timestamp
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ uuid = { version = "1.7", features = ["v4"], optional = true }
[dev-dependencies]
arrow = { workspace = true, features = ["test_utils"] }
criterion = "0.5"
datafusion = { workspace = true, default-features = false }
rand = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "sync"] }

Expand Down
92 changes: 79 additions & 13 deletions datafusion/functions/src/string/ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ use crate::utils::make_scalar_function;
use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array};
use arrow::datatypes::DataType;
use arrow::error::ArrowError;
use datafusion_common::types::logical_string;
use datafusion_common::{internal_err, Result};
use datafusion_expr::{ColumnarValue, Documentation};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature,
TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
use std::any::Any;
use std::sync::Arc;
Expand Down Expand Up @@ -61,7 +64,15 @@ impl Default for AsciiFunc {
impl AsciiFunc {
pub fn new() -> Self {
Self {
signature: Signature::string(1, Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::String(1),
TypeSignature::Coercible(vec![TypeSignatureClass::Native(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we use coercible(string), we don't need string since it is a more strict rule.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what I initially did, but after testing on Sail, I discovered new test failures related to coercing input that's all String (e.g. func(Utf8, Utf8View)).

The plan is to port all the relevant tests from Sail into this PR!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although it was not my intention to apply this pattern on single arg functions. I'll get that fixed!

Copy link
Contributor

@jayzhan211 jayzhan211 Jan 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current design for coercion may still have room for improvement. It would be beneficial to represent the function signature in a simpler and more concise manner, rather than relying on complex combinations of multiple, similar signatures.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed! I'll add that in.

logical_string(),
)]),
],
Volatility::Immutable,
),
}
}
}
Expand Down Expand Up @@ -130,20 +141,23 @@ pub fn ascii(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = args[0].as_string_view();
Ok(calculate_ascii(string_array)?)
}
_ => internal_err!("Unsupported data type"),
other => internal_err!("Unsupported data type for ascii: {:?}", other),
}
}

#[cfg(test)]
mod tests {
use crate::expr_fn::ascii;
use crate::string::ascii::AsciiFunc;
use crate::utils::test::test_function;
use arrow::array::{Array, Int32Array};
use arrow::array::{Array, ArrayRef, Int32Array, RecordBatch, StringArray};
use arrow::datatypes::DataType::Int32;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
use datafusion::prelude::SessionContext;
use datafusion_common::{DFSchema, Result, ScalarValue};
use datafusion_expr::{col, lit, ColumnarValue, ScalarUDFImpl};
use std::sync::Arc;

macro_rules! test_ascii {
macro_rules! test_ascii_invoke {
($INPUT:expr, $EXPECTED:expr) => {
test_function!(
AsciiFunc::new(),
Expand Down Expand Up @@ -174,12 +188,64 @@ mod tests {
};
}

fn ascii_array(input: ArrayRef) -> Result<ArrayRef> {
let batch = RecordBatch::try_from_iter([("c0", input)])?;
let df_schema = DFSchema::try_from(batch.schema())?;
let expr = ascii(col("c0"));
let physical_expr =
SessionContext::new().create_physical_expr(expr, &df_schema)?;
let result = match physical_expr.evaluate(&batch)? {
ColumnarValue::Array(result) => Ok(result),
_ => datafusion_common::internal_err!("ascii"),
}?;
Ok(result)
}

fn ascii_scalar(input: ScalarValue) -> Result<ScalarValue> {
let df_schema = DFSchema::empty();
let expr = ascii(lit(input));
let physical_expr =
SessionContext::new().create_physical_expr(expr, &df_schema)?;
let result = match physical_expr
.evaluate(&RecordBatch::new_empty(Arc::clone(df_schema.inner())))?
{
ColumnarValue::Scalar(result) => Ok(result),
_ => datafusion_common::internal_err!("ascii"),
}?;
Ok(result)
}

#[test]
fn test_functions() -> Result<()> {
test_ascii!(Some(String::from("x")), Ok(Some(120)));
test_ascii!(Some(String::from("a")), Ok(Some(97)));
test_ascii!(Some(String::from("")), Ok(Some(0)));
test_ascii!(None, Ok(None));
fn test_ascii_invoke() -> Result<()> {
test_ascii_invoke!(Some(String::from("x")), Ok(Some(120)));
test_ascii_invoke!(Some(String::from("a")), Ok(Some(97)));
test_ascii_invoke!(Some(String::from("")), Ok(Some(0)));
test_ascii_invoke!(None, Ok(None));
Ok(())
}

#[test]
fn test_ascii_expr() -> Result<()> {
let input = Arc::new(StringArray::from(vec![Some("x")])) as ArrayRef;
let expected = Arc::new(Int32Array::from(vec![Some(120)])) as ArrayRef;
let result = ascii_array(input)?;
assert_eq!(&expected, &result);

let input = ScalarValue::Utf8(Some(String::from("x")));
let expected = ScalarValue::Int32(Some(120));
let result = ascii_scalar(input)?;
assert_eq!(&expected, &result);

let input = Arc::new(Int32Array::from(vec![Some(2)])) as ArrayRef;
let expected = Arc::new(Int32Array::from(vec![Some(50)])) as ArrayRef;
let result = ascii_array(input)?;
assert_eq!(&expected, &result);

let input = ScalarValue::Int32(Some(2));
let expected = ScalarValue::Int32(Some(50));
let result = ascii_scalar(input)?;
assert_eq!(&expected, &result);

Ok(())
}
}
17 changes: 14 additions & 3 deletions datafusion/functions/src/string/bit_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ use arrow::datatypes::DataType;
use std::any::Any;

use crate::utils::utf8_to_int_type;
use datafusion_common::types::logical_string;
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature,
TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

#[user_doc(
Expand Down Expand Up @@ -55,7 +58,15 @@ impl Default for BitLengthFunc {
impl BitLengthFunc {
pub fn new() -> Self {
Self {
signature: Signature::string(1, Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::String(1),
TypeSignature::Coercible(vec![TypeSignatureClass::Native(
logical_string(),
)]),
],
Volatility::Immutable,
),
}
}
}
Expand Down
15 changes: 13 additions & 2 deletions datafusion/functions/src/string/contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ use arrow::compute::contains as arrow_contains;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View};
use datafusion_common::exec_err;
use datafusion_common::types::logical_string;
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature,
TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
use std::any::Any;
Expand Down Expand Up @@ -59,7 +61,16 @@ impl Default for ContainsFunc {
impl ContainsFunc {
pub fn new() -> Self {
Self {
signature: Signature::string(2, Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::String(2),
TypeSignature::Coercible(vec![
TypeSignatureClass::Native(logical_string()),
TypeSignatureClass::Native(logical_string()),
]),
],
Volatility::Immutable,
),
}
}
}
Expand Down
18 changes: 15 additions & 3 deletions datafusion/functions/src/string/ends_with.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ use arrow::array::ArrayRef;
use arrow::datatypes::DataType;

use crate::utils::make_scalar_function;
use datafusion_common::types::logical_string;
use datafusion_common::{internal_err, Result};
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature,
TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

#[user_doc(
Expand Down Expand Up @@ -62,7 +65,16 @@ impl Default for EndsWithFunc {
impl EndsWithFunc {
pub fn new() -> Self {
Self {
signature: Signature::string(2, Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::String(2),
TypeSignature::Coercible(vec![
TypeSignatureClass::Native(logical_string()),
TypeSignatureClass::Native(logical_string()),
]),
],
Volatility::Immutable,
),
}
}
}
Expand Down
17 changes: 14 additions & 3 deletions datafusion/functions/src/string/octet_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ use arrow::datatypes::DataType;
use std::any::Any;

use crate::utils::utf8_to_int_type;
use datafusion_common::types::logical_string;
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature,
TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

#[user_doc(
Expand Down Expand Up @@ -55,7 +58,15 @@ impl Default for OctetLengthFunc {
impl OctetLengthFunc {
pub fn new() -> Self {
Self {
signature: Signature::string(1, Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::String(1),
TypeSignature::Coercible(vec![TypeSignatureClass::Native(
logical_string(),
)]),
],
Volatility::Immutable,
),
}
}
}
Expand Down
18 changes: 15 additions & 3 deletions datafusion/functions/src/string/starts_with.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ use arrow::datatypes::DataType;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};

use crate::utils::make_scalar_function;
use datafusion_common::types::logical_string;
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Documentation, Expr, Like};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_expr::{
ColumnarValue, Documentation, Expr, Like, ScalarUDFImpl, Signature, TypeSignature,
TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

/// Returns true if string starts with prefix.
Expand Down Expand Up @@ -64,7 +67,16 @@ impl Default for StartsWithFunc {
impl StartsWithFunc {
pub fn new() -> Self {
Self {
signature: Signature::string(2, Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::String(2),
TypeSignature::Coercible(vec![
TypeSignatureClass::Native(logical_string()),
TypeSignatureClass::Native(logical_string()),
]),
],
Volatility::Immutable,
),
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion datafusion/physical-plan/src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ use datafusion_execution::TaskContext;
use datafusion_physical_expr::EquivalenceProperties;

/// Execution plan for values list based relation (produces constant rows)
#[deprecated(since = "45.0.0", note = "Use `MemoryExec::try_new_as_values` instead")]
#[deprecated(
since = "45.0.0",
note = "Use `MemoryExec::try_new_as_values` or `MemoryExec::try_new_from_batches` instead"
)]
#[derive(Debug, Clone)]
pub struct ValuesExec {
/// The schema
Expand Down
Loading