Add two new methods in ScalarFunction return_type_from_args and is_nullable_from_args_nullable #14094

Status: Open. Wants to merge 9 commits into base: main
16 changes: 8 additions & 8 deletions datafusion/core/tests/fuzz_cases/equivalence/ordering.rs
@@ -21,9 +21,10 @@ use crate::fuzz_cases::equivalence::utils::{
is_table_same_after_sort, TestScalarUDF,
};
use arrow_schema::SortOptions;
- use datafusion_common::{DFSchema, Result};
+ use datafusion_common::Result;
use datafusion_expr::{Operator, ScalarUDF};
use datafusion_physical_expr::expressions::{col, BinaryExpr};
+ use datafusion_physical_expr::ScalarFunctionExpr;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use itertools::Itertools;
@@ -103,14 +104,13 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> {
let table_data_with_properties =
generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;

- let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
- let floor_a = datafusion_physical_expr::udf::create_physical_expr(
- &test_fun,
- &[col("a", &test_schema)?],
+ let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
+ let col_a = col("a", &test_schema)?;
+ let floor_a = Arc::new(ScalarFunctionExpr::try_new(
+ Arc::clone(&test_fun),
+ vec![col_a],
&test_schema,
- &[],
- &DFSchema::empty(),
- )?;
+ )?);
let a_plus_b = Arc::new(BinaryExpr::new(
col("a", &test_schema)?,
Operator::Plus,
29 changes: 14 additions & 15 deletions datafusion/core/tests/fuzz_cases/equivalence/projection.rs
@@ -20,10 +20,11 @@ use crate::fuzz_cases::equivalence::utils::{
is_table_same_after_sort, TestScalarUDF,
};
use arrow_schema::SortOptions;
- use datafusion_common::{DFSchema, Result};
+ use datafusion_common::Result;
use datafusion_expr::{Operator, ScalarUDF};
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::expressions::{col, BinaryExpr};
+ use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use itertools::Itertools;
@@ -42,14 +43,13 @@ fn project_orderings_random() -> Result<()> {
let table_data_with_properties =
generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;
// Floor(a)
- let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
- let floor_a = datafusion_physical_expr::udf::create_physical_expr(
- &test_fun,
- &[col("a", &test_schema)?],
+ let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
+ let col_a = col("a", &test_schema)?;
+ let floor_a = Arc::new(ScalarFunctionExpr::try_new(
+ Arc::clone(&test_fun),
+ vec![col_a],
&test_schema,
- &[],
- &DFSchema::empty(),
- )?;
+ )?);
// a + b
let a_plus_b = Arc::new(BinaryExpr::new(
col("a", &test_schema)?,
@@ -120,14 +120,13 @@ fn ordering_satisfy_after_projection_random() -> Result<()> {
let table_data_with_properties =
generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;
// Floor(a)
- let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
- let floor_a = datafusion_physical_expr::udf::create_physical_expr(
- &test_fun,
- &[col("a", &test_schema)?],
+ let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
+ let col_a = col("a", &test_schema)?;
+ let floor_a = Arc::new(ScalarFunctionExpr::try_new(
+ Arc::clone(&test_fun),
+ vec![col_a],
&test_schema,
- &[],
- &DFSchema::empty(),
- )?;
+ )?) as PhysicalExprRef;
// a + b
let a_plus_b = Arc::new(BinaryExpr::new(
col("a", &test_schema)?,
17 changes: 9 additions & 8 deletions datafusion/core/tests/fuzz_cases/equivalence/properties.rs
@@ -19,9 +19,10 @@ use crate::fuzz_cases::equivalence::utils::{
create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort,
TestScalarUDF,
};
- use datafusion_common::{DFSchema, Result};
+ use datafusion_common::Result;
use datafusion_expr::{Operator, ScalarUDF};
use datafusion_physical_expr::expressions::{col, BinaryExpr};
+ use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use itertools::Itertools;
@@ -40,14 +41,14 @@ fn test_find_longest_permutation_random() -> Result<()> {
let table_data_with_properties =
generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;

- let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
- let floor_a = datafusion_physical_expr::udf::create_physical_expr(
- &test_fun,
- &[col("a", &test_schema)?],
+ let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
+ let col_a = col("a", &test_schema)?;
+ let floor_a = Arc::new(ScalarFunctionExpr::try_new(
+ Arc::clone(&test_fun),
+ vec![col_a],
&test_schema,
- &[],
- &DFSchema::empty(),
- )?;
+ )?) as PhysicalExprRef;

let a_plus_b = Arc::new(BinaryExpr::new(
col("a", &test_schema)?,
Operator::Plus,
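Note: the three fuzz-test files above apply the same migration. The old `datafusion_physical_expr::udf::create_physical_expr` helper, which also required an empty `DFSchema`, is replaced by constructing a `ScalarFunctionExpr` directly. For reference, a minimal standalone sketch of the new pattern, assuming the three-argument `ScalarFunctionExpr::try_new(udf, args, schema)` signature used in these diffs; the helper name `floor_like_expr` and the column name `a` are illustrative only.

use std::sync::Arc;

use arrow_schema::Schema;
use datafusion_common::Result;
use datafusion_expr::ScalarUDF;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr};

/// Build `udf(a)` as a physical expression straight from the Arrow schema,
/// without the `DFSchema` the old helper needed.
fn floor_like_expr(udf: Arc<ScalarUDF>, schema: &Schema) -> Result<PhysicalExprRef> {
    let arg = col("a", schema)?;
    let expr = ScalarFunctionExpr::try_new(udf, vec![arg], schema)?;
    Ok(Arc::new(expr) as PhysicalExprRef)
}
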
19 changes: 17 additions & 2 deletions datafusion/expr/src/expr_schema.rs
@@ -24,12 +24,13 @@ use crate::type_coercion::binary::get_result_type;
use crate::type_coercion::functions::{
data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf,
};
+ use crate::udf::ReturnTypeArgs;
use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{
not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
- Result, TableReference,
+ Result, ScalarValue, TableReference,
};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use std::collections::HashMap;
@@ -168,9 +169,23 @@ impl ExprSchemable for Expr {
)
})?;

+ let arguments = args
+ .iter()
+ .map(|e| match e {
+ Expr::Literal(ScalarValue::Utf8(s)) => {
+ s.clone().unwrap_or_default()
+ }
+ _ => "".to_string(),
+ })
+ .collect::<Vec<_>>();
+ let args = ReturnTypeArgs {
+ arg_types: &new_data_types,
+ arguments: &arguments,
+ };

// Perform additional function arguments validation (due to limited
// expressiveness of `TypeSignature`), then infer return type
- Ok(func.return_type_from_exprs(args, schema, &new_data_types)?)
+ Ok(func.return_type_from_args(args)?)
}
Expr::WindowFunction(window_function) => self
.data_type_and_nullable_with_window_function(schema, window_function)
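For a call such as `arrow_cast(c1, 'Float64')`, the mapping above keeps only the string values of `Utf8` literals and substitutes an empty string for every other expression (a limitation raised in the review comments on `ReturnTypeArgs` further down). A hedged sketch of the value that gets passed on, with the column type `Int64` chosen arbitrarily:

use arrow::datatypes::DataType;
use datafusion_expr::ReturnTypeArgs;

fn main() {
    // `c1` is not a Utf8 literal, so it contributes an empty string;
    // the second argument is the literal 'Float64'.
    let arg_types = [DataType::Int64, DataType::Utf8];
    let arguments = ["".to_string(), "Float64".to_string()];
    let args = ReturnTypeArgs {
        arg_types: &arg_types,
        arguments: &arguments,
    };
    // The two slices are parallel: one entry per function argument.
    assert_eq!(args.arg_types.len(), args.arguments.len());
}
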
4 changes: 3 additions & 1 deletion datafusion/expr/src/lib.rs
@@ -93,7 +93,9 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{
aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs,
};
- pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
+ pub use udf::{
+ scalar_doc_sections, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
+ };
pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};

63 changes: 63 additions & 0 deletions datafusion/expr/src/udf.rs
@@ -182,6 +182,7 @@ impl ScalarUDF {
///
///
/// See [`ScalarUDFImpl::return_type_from_exprs`] for more details.
+ #[allow(deprecated)]
pub fn return_type_from_exprs(
&self,
args: &[Expr],
@@ -192,6 +193,10 @@
self.inner.return_type_from_exprs(args, schema, arg_types)
}

+ pub fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<DataType> {
+ self.inner.return_type_from_args(args)
+ }

/// Do the function rewrite
///
/// See [`ScalarUDFImpl::simplify`] for more details.
Expand All @@ -209,10 +214,15 @@ impl ScalarUDF {
self.inner.invoke(args)
}

+ #[allow(deprecated)]
pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
self.inner.is_nullable(args, schema)
}

+ pub fn is_nullable_from_args_nullable(&self, args_nullables: &[bool]) -> bool {
[Review comment from the Contributor Author] Remove Expr dependency
+ self.inner.is_nullable_from_args_nullable(args_nullables)
+ }

pub fn invoke_batch(
&self,
args: &[ColumnarValue],
@@ -342,6 +352,14 @@ pub struct ScalarFunctionArgs<'a> {
pub return_type: &'a DataType,
}

+ #[derive(Debug)]
+ pub struct ReturnTypeArgs<'a> {
+ /// The data types of the arguments to the function
+ pub arg_types: &'a [DataType],
+ /// The Utf8 arguments to the function; if an argument is not a Utf8 literal, its entry is an empty string
+ pub arguments: &'a [String],
[Review comment from the Contributor Author] better name 🤔 ?
[Review comment from a Contributor] Would it be possible to unify the argument handling so that both the return type and the nullability are handled the same way? I wonder if it would somehow be possible to add the input nullability information here too 🤔
[Review comment from a Contributor] I am also not sure about only supporting string args; that is likely a behavior regression for some users (for example, they may look for constant integers as well).
+ }

/// Trait for implementing user defined scalar functions.
///
/// This trait exposes the full API for implementing user defined functions and
Expand Down Expand Up @@ -481,6 +499,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// This function must consistently return the same type for the same
/// logical input even if the input is simplified (e.g. it must return the same
/// value for `('foo' | 'bar')` as it does for ('foobar').
+ #[deprecated(since = "45.0.0", note = "Use `return_type_from_args` instead")]
fn return_type_from_exprs(
&self,
_args: &[Expr],
@@ -490,10 +509,49 @@
self.return_type(arg_types)
}

+ /// What [`DataType`] will be returned by this function, given the
+ /// arguments?
+ ///
+ /// Note most UDFs should implement [`Self::return_type`] and not this
+ /// function. The output type for most functions only depends on the types
+ /// of their inputs (e.g. `sqrt(f32)` is always `f32`).
+ ///
+ /// By default, this function calls [`Self::return_type`] with the
+ /// types of each argument.
+ ///
+ /// This method can be overridden for functions that return different
+ /// *types* based on the *values* of their arguments.
+ ///
+ /// For example, the following two function calls get the same argument
+ /// types (something and a `Utf8` string) but return different types based
+ /// on the value of the second argument:
+ ///
+ /// * `arrow_cast(x, 'Int16')` --> `Int16`
+ /// * `arrow_cast(x, 'Float32')` --> `Float32`
+ ///
+ /// # Notes:
+ ///
+ /// This function must consistently return the same type for the same
+ /// logical input even if the input is simplified (e.g. it must return the same
+ /// value for `('foo' | 'bar')` as it does for ('foobar').
+ fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<DataType> {
+ self.return_type(args.arg_types)
+ }
+
+ #[deprecated(
+ since = "45.0.0",
+ note = "Use `is_nullable_from_args_nullable` instead"
+ )]
fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
true
}

+ /// `is_nullable` computed from pre-computed nullability flags.
+ /// It has fewer dependencies on the input arguments.
+ fn is_nullable_from_args_nullable(&self, _args_nullables: &[bool]) -> bool {
+ true
+ }

/// Invoke the function on `args`, returning the appropriate result
///
/// Note: This method is deprecated and will be removed in future releases.
@@ -787,6 +845,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
&self.aliases
}

+ #[allow(deprecated)]
fn return_type_from_exprs(
&self,
args: &[Expr],
@@ -796,6 +855,10 @@
self.inner.return_type_from_exprs(args, schema, arg_types)
}

+ fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<DataType> {
+ self.inner.return_type_from_args(args)
+ }

fn invoke_batch(
&self,
args: &[ColumnarValue],
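To show the new trait surface end to end, here is a hypothetical downstream UDF (not part of this PR) that overrides both new methods. The name `typed_parse`, its signature, and its semantics are invented for the example; only the `ReturnTypeArgs` struct and trait methods shown in the diff above, plus the existing `Signature` and `invoke_batch` APIs, are assumed.

use std::any::Any;

use arrow::datatypes::DataType;
use datafusion_common::{exec_err, plan_err, Result};
use datafusion_expr::{
    ColumnarValue, ReturnTypeArgs, ScalarUDFImpl, Signature, Volatility,
};

/// Hypothetical function whose return type depends on the *value* of its
/// second argument, in the same spirit as `arrow_cast`.
#[derive(Debug)]
struct TypedParse {
    signature: Signature,
}

impl TypedParse {
    fn new() -> Self {
        Self {
            // Two Utf8 arguments: the text to parse and a constant type name.
            signature: Signature::uniform(2, vec![DataType::Utf8], Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for TypedParse {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "typed_parse"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        // The type depends on an argument *value*, so the planner should call
        // `return_type_from_args` instead.
        plan_err!("typed_parse: return_type_from_args should be called instead")
    }

    fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<DataType> {
        // `arguments` holds the Utf8 literal values ("" for non-literals).
        match args.arguments.get(1).map(String::as_str) {
            Some("int") => Ok(DataType::Int64),
            Some("float") => Ok(DataType::Float64),
            _ => plan_err!("typed_parse requires 'int' or 'float' as its second argument"),
        }
    }

    fn is_nullable_from_args_nullable(&self, args_nullables: &[bool]) -> bool {
        // The output can be NULL only when the parsed input itself is nullable.
        args_nullables.first().copied().unwrap_or(true)
    }

    fn invoke_batch(&self, _args: &[ColumnarValue], _number_rows: usize) -> Result<ColumnarValue> {
        exec_err!("typed_parse: execution is out of scope for this sketch")
    }
}
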
45 changes: 30 additions & 15 deletions datafusion/functions/src/core/arrow_cast.rs
@@ -18,17 +18,18 @@
//! [`ArrowCastFunc`]: Implementation of the `arrow_cast`

use arrow::datatypes::DataType;
+ use datafusion_common::DataFusionError;
use datafusion_common::{
- arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError,
- ExprSchema, Result, ScalarValue,
+ arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, Result,
+ ScalarValue,
};
use std::any::Any;
use std::sync::OnceLock;

use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
- ColumnarValue, Documentation, Expr, ExprSchemable, ScalarUDFImpl, Signature,
+ ColumnarValue, Documentation, Expr, ReturnTypeArgs, ScalarUDFImpl, Signature,
Volatility,
};

@@ -86,22 +87,36 @@ impl ScalarUDFImpl for ArrowCastFunc {
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
[Review comment from @jayzhan211 (Contributor Author), Jan 12, 2025] If this change looks good, we can deprecate this too
- // should be using return_type_from_exprs and not calling the default
- // implementation
- internal_err!("arrow_cast should return type from exprs")
+ internal_err!("return_type_from_args should be called instead")
}

- fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
- args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true))
+ fn is_nullable_from_args_nullable(&self, args_nullables: &[bool]) -> bool {
+ args_nullables.iter().any(|&nullable| nullable)
}

- fn return_type_from_exprs(
- &self,
- args: &[Expr],
- _schema: &dyn ExprSchema,
- _arg_types: &[DataType],
- ) -> Result<DataType> {
- data_type_from_args(args)
+ fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<DataType> {
+ if args.arguments.len() != 2 {
+ return plan_err!(
+ "{} needs 2 arguments, {} provided",
+ self.name(),
+ args.arguments.len()
+ );
+ }
+
+ let val = &args.arguments[1];
+ if val.is_empty() {
+ return plan_err!(
+ "{} requires its second argument to be a constant string",
+ self.name()
+ );
+ };
+
+ val.parse().map_err(|e| match e {
+ // If the data type cannot be parsed, return a Plan error to signal an
+ // error in the input rather than a more general ArrowError
+ arrow::error::ArrowError::ParseError(e) => plan_datafusion_err!("{e}"),
+ e => arrow_datafusion_err!(e),
+ })
}

fn invoke_batch(
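A rough check of the implementation above, assuming `ArrowCastFunc::new()` stays public under `datafusion_functions::core::arrow_cast` and that the `DataType` `FromStr` impl used by `val.parse()` accepts names such as `Float32`:

use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::{ReturnTypeArgs, ScalarUDFImpl};
use datafusion_functions::core::arrow_cast::ArrowCastFunc;

fn main() -> Result<()> {
    let func = ArrowCastFunc::new();
    // The second argument is the Utf8 literal 'Float32', so planning resolves
    // the call to a Float32 output, just as the old Expr-based path did.
    let arg_types = [DataType::Int64, DataType::Utf8];
    let arguments = ["".to_string(), "Float32".to_string()];
    let return_type = func.return_type_from_args(ReturnTypeArgs {
        arg_types: &arg_types,
        arguments: &arguments,
    })?;
    assert_eq!(return_type, DataType::Float32);
    Ok(())
}
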
8 changes: 4 additions & 4 deletions datafusion/functions/src/core/coalesce.rs
@@ -19,10 +19,10 @@ use arrow::array::{new_null_array, BooleanArray};
use arrow::compute::kernels::zip::zip;
use arrow::compute::{and, is_not_null, is_null};
use arrow::datatypes::DataType;
- use datafusion_common::{exec_err, ExprSchema, Result};
+ use datafusion_common::{exec_err, Result};
use datafusion_expr::binary::try_type_union_resolution;
use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL;
- use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable};
+ use datafusion_expr::{ColumnarValue, Documentation};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use itertools::Itertools;
use std::any::Any;
@@ -69,8 +69,8 @@ impl ScalarUDFImpl for CoalesceFunc {
}

// If any the arguments in coalesce is non-null, the result is non-null
- fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
- args.iter().all(|e| e.nullable(schema).ok().unwrap_or(true))
+ fn is_nullable_from_args_nullable(&self, args_nullables: &[bool]) -> bool {
+ args_nullables.iter().all(|&nullable| nullable)
}

/// coalesce evaluates to the first value which is not NULL
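With the version above, `coalesce` reports a non-nullable result as soon as any input is non-nullable, matching the old `Expr`-based check. A small hedged example, assuming `CoalesceFunc::new()` is public like the other core functions:

use datafusion_expr::ScalarUDFImpl;
use datafusion_functions::core::coalesce::CoalesceFunc;

fn main() {
    let func = CoalesceFunc::new();
    // Every input may be NULL, so the result may be NULL.
    assert!(func.is_nullable_from_args_nullable(&[true, true]));
    // One non-nullable input guarantees a non-NULL result.
    assert!(!func.is_nullable_from_args_nullable(&[true, false]));
}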