diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs index ecf267185bae..cd9897d43baa 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -21,9 +21,10 @@ use crate::fuzz_cases::equivalence::utils::{ is_table_same_after_sort, TestScalarUDF, }; use arrow_schema::SortOptions; -use datafusion_common::{DFSchema, Result}; +use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use itertools::Itertools; @@ -103,14 +104,13 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = datafusion_physical_expr::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], + let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new())); + let col_a = col("a", &test_schema)?; + let floor_a = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![col_a], &test_schema, - &[], - &DFSchema::empty(), - )?; + )?); let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, Operator::Plus, diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs index f71df50fce2f..78fbda16c0a0 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -20,10 +20,11 @@ use crate::fuzz_cases::equivalence::utils::{ is_table_same_after_sort, TestScalarUDF, }; use arrow_schema::SortOptions; -use datafusion_common::{DFSchema, Result}; +use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use itertools::Itertools; @@ -42,14 +43,13 @@ fn project_orderings_random() -> Result<()> { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; // Floor(a) - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = datafusion_physical_expr::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], + let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new())); + let col_a = col("a", &test_schema)?; + let floor_a = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![col_a], &test_schema, - &[], - &DFSchema::empty(), - )?; + )?); // a + b let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, @@ -120,14 +120,13 @@ fn ordering_satisfy_after_projection_random() -> Result<()> { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; // Floor(a) - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = datafusion_physical_expr::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], + let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new())); + let col_a = col("a", &test_schema)?; + let floor_a = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![col_a], &test_schema, - &[], - &DFSchema::empty(), - )?; + )?) as PhysicalExprRef; // a + b let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, diff --git a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs index fc21c620a711..593e1c6c2dca 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs @@ -19,9 +19,10 @@ use crate::fuzz_cases::equivalence::utils::{ create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, TestScalarUDF, }; -use datafusion_common::{DFSchema, Result}; +use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use itertools::Itertools; @@ -40,14 +41,14 @@ fn test_find_longest_permutation_random() -> Result<()> { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = datafusion_physical_expr::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], + let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new())); + let col_a = col("a", &test_schema)?; + let floor_a = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![col_a], &test_schema, - &[], - &DFSchema::empty(), - )?; + )?) as PhysicalExprRef; + let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, Operator::Plus, diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 30b3c6e2bbeb..a228eb0286aa 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -34,13 +34,12 @@ use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, internal_err, - not_impl_err, plan_err, DFSchema, DataFusionError, ExprSchema, HashMap, Result, - ScalarValue, + not_impl_err, plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue, }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, ExprSchemable, - LogicalPlanBuilder, OperateFunctionArg, ScalarUDF, ScalarUDFImpl, Signature, + Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, + OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions_nested::range::range_udf; @@ -819,32 +818,36 @@ impl ScalarUDFImpl for TakeUDF { /// /// 1. If the third argument is '0', return the type of the first argument /// 2. If the third argument is '1', return the type of the second argument - fn return_type_from_exprs( - &self, - arg_exprs: &[Expr], - schema: &dyn ExprSchema, - _arg_data_types: &[DataType], - ) -> Result { - if arg_exprs.len() != 3 { - return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len()); + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + if args.arg_types.len() != 3 { + return plan_err!("Expected 3 arguments, got {}.", args.arg_types.len()); } - let take_idx = if let Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) = - arg_exprs.get(2) - { - if *idx == 0 || *idx == 1 { - *idx as usize + let take_idx = if let Some(take_idx) = args.scalar_arguments.get(2) { + // This is for test only, safe to unwrap + let take_idx = take_idx + .unwrap() + .try_as_str() + .unwrap() + .unwrap() + .parse::() + .unwrap(); + + if take_idx == 0 || take_idx == 1 { + take_idx } else { - return plan_err!("The third argument must be 0 or 1, got: {idx}"); + return plan_err!("The third argument must be 0 or 1, got: {take_idx}"); } } else { return plan_err!( "The third argument must be a literal of type int64, but got {:?}", - arg_exprs.get(2) + args.scalar_arguments.get(2) ); }; - arg_exprs.get(take_idx).unwrap().get_type(schema) + Ok(ReturnInfo::new_nullable( + args.arg_types[take_idx].to_owned(), + )) } // The actual implementation @@ -854,7 +857,8 @@ impl ScalarUDFImpl for TakeUDF { _number_rows: usize, ) -> Result { let take_idx = match &args[2] { - ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize, + ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "0" => 0, + ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "1" => 1, _ => unreachable!(), }; match &args[take_idx] { @@ -874,9 +878,9 @@ async fn verify_udf_return_type() -> Result<()> { // take(smallint_col, double_col, 1) as take1 // FROM alltypes_plain; let exprs = vec![ - take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)]) + take.call(vec![col("smallint_col"), col("double_col"), lit("0")]) .alias("take0"), - take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)]) + take.call(vec![col("smallint_col"), col("double_col"), lit("1")]) .alias("take1"), ]; diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 25073ca7eaaa..08eb06160c09 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -24,6 +24,7 @@ use crate::type_coercion::binary::get_result_type; use crate::type_coercion::functions::{ data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, }; +use crate::udf::ReturnTypeArgs; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; @@ -145,32 +146,9 @@ impl ExprSchemable for Expr { } } } - Expr::ScalarFunction(ScalarFunction { func, args }) => { - let arg_data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - let new_data_types = data_types_with_scalar_udf(&arg_data_types, func) - .map_err(|err| { - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - func.name(), - func.signature().clone(), - &arg_data_types, - ) - ) - })?; - - // Perform additional function arguments validation (due to limited - // expressiveness of `TypeSignature`), then infer return type - Ok(func.return_type_from_exprs(args, schema, &new_data_types)?) + Expr::ScalarFunction(_func) => { + let (return_type, _) = self.data_type_and_nullable(schema)?; + Ok(return_type) } Expr::WindowFunction(window_function) => self .data_type_and_nullable_with_window_function(schema, window_function) @@ -303,8 +281,9 @@ impl ExprSchemable for Expr { } } Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), - Expr::ScalarFunction(ScalarFunction { func, args }) => { - Ok(func.is_nullable(args, input_schema)) + Expr::ScalarFunction(_func) => { + let (_, nullable) = self.data_type_and_nullable(input_schema)?; + Ok(nullable) } Expr::AggregateFunction(AggregateFunction { func, .. }) => { Ok(func.is_nullable()) @@ -415,6 +394,47 @@ impl ExprSchemable for Expr { Expr::WindowFunction(window_function) => { self.data_type_and_nullable_with_window_function(schema, window_function) } + Expr::ScalarFunction(ScalarFunction { func, args }) => { + let (arg_types, nullables): (Vec, Vec) = args + .iter() + .map(|e| e.data_type_and_nullable(schema)) + .collect::>>()? + .into_iter() + .unzip(); + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` + let new_data_types = data_types_with_scalar_udf(&arg_types, func) + .map_err(|err| { + plan_datafusion_err!( + "{} {}", + match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }, + utils::generate_signature_error_msg( + func.name(), + func.signature().clone(), + &arg_types, + ) + ) + })?; + + let arguments = args + .iter() + .map(|e| match e { + Expr::Literal(sv) => Some(sv), + _ => None, + }) + .collect::>(); + let args = ReturnTypeArgs { + arg_types: &new_data_types, + scalar_arguments: &arguments, + nullables: &nullables, + }; + + let (return_type, nullable) = + func.return_type_from_args(args)?.into_parts(); + Ok((return_type, nullable)) + } _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index a57fd80c48e1..017415da8f23 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -93,7 +93,10 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{ aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs, }; -pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +pub use udf::{ + scalar_doc_sections, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, +}; pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index ffac82265a00..bb5a405a9352 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -24,7 +24,7 @@ use crate::{ ColumnarValue, Documentation, Expr, ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::{not_impl_err, ExprSchema, Result}; +use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; use std::cmp::Ordering; @@ -182,6 +182,7 @@ impl ScalarUDF { /// /// /// See [`ScalarUDFImpl::return_type_from_exprs`] for more details. + #[allow(deprecated)] pub fn return_type_from_exprs( &self, args: &[Expr], @@ -192,6 +193,10 @@ impl ScalarUDF { self.inner.return_type_from_exprs(args, schema, arg_types) } + pub fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + self.inner.return_type_from_args(args) + } + /// Do the function rewrite /// /// See [`ScalarUDFImpl::simplify`] for more details. @@ -209,6 +214,7 @@ impl ScalarUDF { self.inner.invoke(args) } + #[allow(deprecated)] pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { self.inner.is_nullable(args, schema) } @@ -342,6 +348,72 @@ pub struct ScalarFunctionArgs<'a> { pub return_type: &'a DataType, } +/// Information about arguments passed to the function +/// +/// This structure contains metadata about how the function was called +/// such as the type of the arguments, any scalar arguments and if the +/// arguments can (ever) be null +/// +/// See [`ScalarUDFImpl::return_type_from_args`] for more information +#[derive(Debug)] +pub struct ReturnTypeArgs<'a> { + /// The data types of the arguments to the function + pub arg_types: &'a [DataType], + /// Is argument `i` to the function a scalar (constant) + /// + /// If argument `i` is not a scalar, it will be None + /// + /// For example, if a function is called like `my_function(column_a, 5)` + /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` + pub scalar_arguments: &'a [Option<&'a ScalarValue>], + /// Can argument `i` (ever) null? + pub nullables: &'a [bool], +} + +/// Return metadata for this function. +/// +/// See [`ScalarUDFImpl::return_type_from_args`] for more information +#[derive(Debug)] +pub struct ReturnInfo { + return_type: DataType, + nullable: bool, +} + +impl ReturnInfo { + pub fn new(return_type: DataType, nullable: bool) -> Self { + Self { + return_type, + nullable, + } + } + + pub fn new_nullable(return_type: DataType) -> Self { + Self { + return_type, + nullable: true, + } + } + + pub fn new_non_nullable(return_type: DataType) -> Self { + Self { + return_type, + nullable: false, + } + } + + pub fn return_type(&self) -> &DataType { + &self.return_type + } + + pub fn nullable(&self) -> bool { + self.nullable + } + + pub fn into_parts(self) -> (DataType, bool) { + (self.return_type, self.nullable) + } +} + /// Trait for implementing user defined scalar functions. /// /// This trait exposes the full API for implementing user defined functions and @@ -449,13 +521,23 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// # Notes /// - /// If you provide an implementation for [`Self::return_type_from_exprs`], + /// If you provide an implementation for [`Self::return_type_from_args`], /// DataFusion will not call `return_type` (this function). In this case it /// is recommended to return [`DataFusionError::Internal`]. /// /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal fn return_type(&self, arg_types: &[DataType]) -> Result; + #[deprecated(since = "45.0.0", note = "Use `return_type_from_args` instead")] + fn return_type_from_exprs( + &self, + _args: &[Expr], + _schema: &dyn ExprSchema, + arg_types: &[DataType], + ) -> Result { + self.return_type(arg_types) + } + /// What [`DataType`] will be returned by this function, given the /// arguments? /// @@ -481,15 +563,15 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// This function must consistently return the same type for the same /// logical input even if the input is simplified (e.g. it must return the same /// value for `('foo' | 'bar')` as it does for ('foobar'). - fn return_type_from_exprs( - &self, - _args: &[Expr], - _schema: &dyn ExprSchema, - arg_types: &[DataType], - ) -> Result { - self.return_type(arg_types) + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + let return_type = self.return_type(args.arg_types)?; + Ok(ReturnInfo::new_nullable(return_type)) } + #[deprecated( + since = "45.0.0", + note = "Use `return_type_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_type_from_args`, you might have error" + )] fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { true } @@ -787,6 +869,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { &self.aliases } + #[allow(deprecated)] fn return_type_from_exprs( &self, args: &[Expr], @@ -796,6 +879,10 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type_from_exprs(args, schema, arg_types) } + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + self.inner.return_type_from_args(args) + } + fn invoke_batch( &self, args: &[ColumnarValue], diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 8e4ae36c9a66..b0fba57460f8 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -18,16 +18,17 @@ //! [`ArrowCastFunc`]: Implementation of the `arrow_cast` use arrow::datatypes::DataType; +use arrow::error::ArrowError; use datafusion_common::{ - arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, - ExprSchema, Result, ScalarValue, + arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue, }; +use datafusion_common::{exec_datafusion_err, DataFusionError}; use std::any::Any; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ExprSchemable, ScalarUDFImpl, Signature, - Volatility, + ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, + Signature, Volatility, }; use datafusion_macros::user_doc; @@ -110,22 +111,30 @@ impl ScalarUDFImpl for ArrowCastFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - // should be using return_type_from_exprs and not calling the default - // implementation - internal_err!("arrow_cast should return type from exprs") + internal_err!("return_type_from_args should be called instead") } - fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { - args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true)) - } - - fn return_type_from_exprs( - &self, - args: &[Expr], - _schema: &dyn ExprSchema, - _arg_types: &[DataType], - ) -> Result { - data_type_from_args(args) + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + let nullable = args.nullables.iter().any(|&nullable| nullable); + + // Length check handled in the signature + debug_assert_eq!(args.scalar_arguments.len(), 2); + + args.scalar_arguments[1] + .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty())) + .map_or_else( + || { + exec_err!( + "{} requires its second argument to be a non-empty constant string", + self.name() + ) + }, + |casted_type| match casted_type.parse::() { + Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), + Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), + Err(e) => Err(arrow_datafusion_err!(e)), + }, + ) } fn invoke_batch( @@ -170,10 +179,10 @@ impl ScalarUDFImpl for ArrowCastFunc { /// Returns the requested type from the arguments fn data_type_from_args(args: &[Expr]) -> Result { if args.len() != 2 { - return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); + return exec_err!("arrow_cast needs 2 arguments, {} provided", args.len()); } let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else { - return plan_err!( + return exec_err!( "arrow_cast requires its second argument to be a constant string, got {:?}", &args[1] ); @@ -182,7 +191,7 @@ fn data_type_from_args(args: &[Expr]) -> Result { val.parse().map_err(|e| match e { // If the data type cannot be parsed, return a Plan error to signal an // error in the input rather than a more general ArrowError - arrow::error::ArrowError::ParseError(e) => plan_datafusion_err!("{e}"), + ArrowError::ParseError(e) => exec_datafusion_err!("{e}"), e => arrow_datafusion_err!(e), }) } diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index bfd69bab6656..602fe0fd9585 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -19,9 +19,9 @@ use arrow::array::{new_null_array, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; use arrow::datatypes::DataType; -use datafusion_common::{exec_err, ExprSchema, Result}; +use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::binary::try_type_union_resolution; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; +use datafusion_expr::{ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use itertools::Itertools; @@ -76,17 +76,20 @@ impl ScalarUDFImpl for CoalesceFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_type_from_args should be called instead") + } + + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + // If any the arguments in coalesce is non-null, the result is non-null + let nullable = args.nullables.iter().all(|&nullable| nullable); + let return_type = args + .arg_types .iter() .find_or_first(|d| !d.is_null()) .unwrap() - .clone()) - } - - // If any the arguments in coalesce is non-null, the result is non-null - fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { - args.iter().all(|e| e.nullable(schema).ok().unwrap_or(true)) + .clone(); + Ok(ReturnInfo::new(return_type, nullable)) } /// coalesce evaluates to the first value which is not NULL @@ -165,39 +168,3 @@ impl ScalarUDFImpl for CoalesceFunc { self.doc() } } - -#[cfg(test)] -mod test { - use arrow::datatypes::DataType; - - use datafusion_expr::ScalarUDFImpl; - - use crate::core; - - #[test] - fn test_coalesce_return_types() { - let coalesce = core::coalesce::CoalesceFunc::new(); - let return_type = coalesce - .return_type(&[DataType::Date32, DataType::Date32]) - .unwrap(); - assert_eq!(return_type, DataType::Date32); - } - - #[test] - fn test_coalesce_return_types_with_nulls_first() { - let coalesce = core::coalesce::CoalesceFunc::new(); - let return_type = coalesce - .return_type(&[DataType::Null, DataType::Date32]) - .unwrap(); - assert_eq!(return_type, DataType::Date32); - } - - #[test] - fn test_coalesce_return_types_with_nulls_last() { - let coalesce = core::coalesce::CoalesceFunc::new(); - let return_type = coalesce - .return_type(&[DataType::Int64, DataType::Null]) - .unwrap(); - assert_eq!(return_type, DataType::Int64); - } -} diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index f9e700b5dbfd..7c72d4594583 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -21,9 +21,9 @@ use arrow::array::{ use arrow::datatypes::DataType; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ - exec_err, plan_datafusion_err, plan_err, ExprSchema, Result, ScalarValue, + exec_err, internal_err, plan_datafusion_err, Result, ScalarValue, }; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; @@ -143,32 +143,14 @@ impl ScalarUDFImpl for GetFieldFunc { } fn return_type(&self, _: &[DataType]) -> Result { - todo!() + internal_err!("return_type_from_args should be called instead") } - fn return_type_from_exprs( - &self, - args: &[Expr], - schema: &dyn ExprSchema, - _arg_types: &[DataType], - ) -> Result { - if args.len() != 2 { - return exec_err!( - "get_field function requires 2 arguments, got {}", - args.len() - ); - } + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + // Length check handled in the signature + debug_assert_eq!(args.scalar_arguments.len(), 2); - let name = match &args[1] { - Expr::Literal(name) => name, - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } - }; - let data_type = args[0].get_type(schema)?; - match (data_type, name) { + match (&args.arg_types[0], args.scalar_arguments[1].as_ref()) { (DataType::Map(fields, _), _) => { match fields.data_type() { DataType::Struct(fields) if fields.len() == 2 => { @@ -177,26 +159,23 @@ impl ScalarUDFImpl for GetFieldFunc { // instead, we assume that the second column is the "value" column both here and in // execution. let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(value_field.data_type().clone()) + Ok(ReturnInfo::new_nullable(value_field.data_type().clone())) }, - _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), - } - } - (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { - if s.is_empty() { - plan_err!( - "Struct based indexed access requires a non empty string" - ) - } else { - let field = fields.iter().find(|f| f.name() == s); - field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| f.data_type().clone()) + _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"), } } - (DataType::Struct(_), _) => plan_err!( - "Only UTF8 strings are valid as an indexed field in a struct" - ), - (DataType::Null, _) => Ok(DataType::Null), - (other, _) => plan_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), + (DataType::Struct(fields),sv) => { + sv.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty())) + .map_or_else( + || exec_err!("Field name must be a non-empty string"), + |field_name| { + fields.iter().find(|f| f.name() == field_name) + .ok_or(plan_datafusion_err!("Field {field_name} not found in struct")) + .map(|f| ReturnInfo::new_nullable(f.data_type().to_owned())) + }) + }, + (DataType::Null, _) => Ok(ReturnInfo::new_nullable(DataType::Null)), + (other, _) => exec_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), } } diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 527ec48fccaa..70c9a425790c 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -18,7 +18,7 @@ use arrow::array::StructArray; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{exec_err, internal_err, HashSet, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; +use datafusion_expr::{ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; @@ -44,11 +44,18 @@ fn named_struct_expr(args: &[ColumnarValue]) -> Result { .chunks_exact(2) .enumerate() .map(|(i, chunk)| { - let name_column = &chunk[0]; let name = match name_column { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => name_scalar, - _ => return exec_err!("named_struct even arguments must be string literals, got {name_column:?} instead at position {}", i * 2) + ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => { + name_scalar + } + // TODO: Implement Display for ColumnarValue + _ => { + return exec_err!( + "named_struct even arguments must be string literals at position {}", + i * 2 + ) + } }; Ok((name, chunk[1].clone())) @@ -148,46 +155,52 @@ impl ScalarUDFImpl for NamedStructFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!( - "named_struct: return_type called instead of return_type_from_exprs" - ) + internal_err!("named_struct: return_type called instead of return_type_from_args") } - fn return_type_from_exprs( - &self, - args: &[Expr], - schema: &dyn datafusion_common::ExprSchema, - _arg_types: &[DataType], - ) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { // do not accept 0 arguments. - if args.is_empty() { + if args.scalar_arguments.is_empty() { return exec_err!( "named_struct requires at least one pair of arguments, got 0 instead" ); } - if args.len() % 2 != 0 { + if args.scalar_arguments.len() % 2 != 0 { return exec_err!( "named_struct requires an even number of arguments, got {} instead", - args.len() + args.scalar_arguments.len() ); } - let return_fields = args - .chunks_exact(2) + let names = args + .scalar_arguments + .iter() .enumerate() - .map(|(i, chunk)| { - let name = &chunk[0]; - let value = &chunk[1]; - - if let Expr::Literal(ScalarValue::Utf8(Some(name))) = name { - Ok(Field::new(name, value.get_type(schema)?, true)) - } else { - exec_err!("named_struct even arguments must be string literals, got {name} instead at position {}", i * 2) - } - }) + .step_by(2) + .map(|(i, sv)| + sv.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty())) + .map_or_else( + || + exec_err!( + "{} requires {i}-th (0-indexed) field name as non-empty constant string", + self.name() + ), + Ok + ) + ) + .collect::>>()?; + let types = args.arg_types.iter().skip(1).step_by(2).collect::>(); + + let return_fields = names + .into_iter() + .zip(types.into_iter()) + .map(|(name, data_type)| Ok(Field::new(name, data_type.to_owned(), true))) .collect::>>()?; - Ok(DataType::Struct(Fields::from(return_fields))) + + Ok(ReturnInfo::new_nullable(DataType::Struct(Fields::from( + return_fields, + )))) } fn invoke_batch( diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 0f01b6a21b0a..bec378e137c0 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -38,11 +38,11 @@ use datafusion_common::{ }, exec_err, internal_err, types::logical_string, - ExprSchema, Result, ScalarValue, + Result, ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use datafusion_expr_common::signature::TypeSignatureClass; use datafusion_macros::user_doc; @@ -136,21 +136,30 @@ impl ScalarUDFImpl for DatePartFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_exprs should be called instead") + internal_err!("return_type_from_args should be called instead") } - fn return_type_from_exprs( - &self, - args: &[Expr], - _schema: &dyn ExprSchema, - _arg_types: &[DataType], - ) -> Result { - match &args[0] { - Expr::Literal(ScalarValue::Utf8(Some(part))) if is_epoch(part) => { - Ok(DataType::Float64) - } - _ => Ok(DataType::Int32), - } + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + // Length check handled in the signature + debug_assert_eq!(args.scalar_arguments.len(), 2); + + args.scalar_arguments[0] + .and_then(|sv| { + sv.try_as_str() + .flatten() + .filter(|s| !s.is_empty()) + .map(|part| { + if is_epoch(part) { + ReturnInfo::new_nullable(DataType::Float64) + } else { + ReturnInfo::new_nullable(DataType::Int32) + } + }) + }) + .map_or_else( + || exec_err!("{} requires non-empty constant string", self.name()), + Ok, + ) } fn invoke_batch( diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 425da7ddac29..534b7a4fa638 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -21,10 +21,11 @@ use std::sync::Arc; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Int64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::Second; -use datafusion_common::{exec_err, internal_err, ExprSchema, Result, ScalarValue}; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -81,29 +82,39 @@ impl ScalarUDFImpl for FromUnixtimeFunc { &self.signature } - fn return_type_from_exprs( - &self, - args: &[Expr], - _schema: &dyn ExprSchema, - arg_types: &[DataType], - ) -> Result { - match arg_types.len() { - 1 => Ok(Timestamp(Second, None)), - 2 => match &args[1] { - Expr::Literal(ScalarValue::Utf8(Some(tz))) => Ok(Timestamp(Second, Some(Arc::from(tz.to_string())))), - _ => exec_err!( - "Second argument for `from_unixtime` must be non-null utf8, received {:?}", - arg_types[1]), - }, - _ => exec_err!( - "from_unixtime function requires 1 or 2 arguments, got {}", - arg_types.len() - ), + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + // Length check handled in the signature + debug_assert!(matches!(args.scalar_arguments.len(), 1 | 2)); + + if args.scalar_arguments.len() == 1 { + Ok(ReturnInfo::new_nullable(Timestamp(Second, None))) + } else { + args.scalar_arguments[1] + .and_then(|sv| { + sv.try_as_str() + .flatten() + .filter(|s| !s.is_empty()) + .map(|tz| { + ReturnInfo::new_nullable(Timestamp( + Second, + Some(Arc::from(tz.to_string())), + )) + }) + }) + .map_or_else( + || { + exec_err!( + "{} requires its second argument to be a constant string", + self.name() + ) + }, + Ok, + ) } } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("call return_type_from_exprs instead") + internal_err!("call return_type_from_args instead") } fn invoke_batch( diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index 67cd49b7fd84..76e875737637 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -20,10 +20,11 @@ use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; use std::any::Any; -use datafusion_common::{internal_err, ExprSchema, Result, ScalarValue}; +use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, + Signature, Volatility, }; use datafusion_macros::user_doc; @@ -76,8 +77,15 @@ impl ScalarUDFImpl for NowFunc { &self.signature } + fn return_type_from_args(&self, _args: ReturnTypeArgs) -> Result { + Ok(ReturnInfo::new_non_nullable(Timestamp( + Nanosecond, + Some("+00:00".into()), + ))) + } + fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Nanosecond, Some("+00:00".into()))) + internal_err!("return_type_from_args should be called instead") } fn invoke_batch( @@ -106,10 +114,6 @@ impl ScalarUDFImpl for NowFunc { &self.aliases } - fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { - false - } - fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 4e324663dcd1..a72759b5d49a 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -274,11 +274,14 @@ mod tests { }; use crate::expressions::{col, BinaryExpr, Column}; use crate::utils::tests::TestScalarUDF; - use crate::{AcrossPartitions, ConstExpr, PhysicalExpr, PhysicalSortExpr}; + use crate::{ + AcrossPartitions, ConstExpr, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, + ScalarFunctionExpr, + }; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::SortOptions; - use datafusion_common::{DFSchema, Result}; + use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr_common::sort_expr::LexOrdering; @@ -327,28 +330,24 @@ mod tests { let col_d = &col("d", &test_schema)?; let col_e = &col("e", &test_schema)?; let col_f = &col("f", &test_schema)?; - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = &crate::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], + let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new())); + + let floor_a = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![Arc::clone(col_a)], &test_schema, - &[], - &DFSchema::empty(), - )?; - let floor_f = &crate::udf::create_physical_expr( - &test_fun, - &[col("f", &test_schema)?], + )?) as PhysicalExprRef; + let floor_f = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![Arc::clone(col_f)], &test_schema, - &[], - &DFSchema::empty(), - )?; - let exp_a = &crate::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], + )?) as PhysicalExprRef; + let exp_a = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![Arc::clone(col_a)], &test_schema, - &[], - &DFSchema::empty(), - )?; + )?) as PhysicalExprRef; + let a_plus_b = Arc::new(BinaryExpr::new( Arc::clone(col_a), Operator::Plus, @@ -392,7 +391,7 @@ mod tests { // constants vec![col_e], // requirement [floor(a) ASC], - vec![(floor_a, options)], + vec![(&floor_a, options)], // expected: requirement is satisfied. true, ), @@ -410,7 +409,7 @@ mod tests { // constants vec![col_e], // requirement [floor(f) ASC], (Please note that a=f) - vec![(floor_f, options)], + vec![(&floor_f, options)], // expected: requirement is satisfied. true, ), @@ -449,7 +448,7 @@ mod tests { // constants vec![col_e], // requirement [floor(a) ASC, a+b ASC], - vec![(floor_a, options), (&a_plus_b, options)], + vec![(&floor_a, options), (&a_plus_b, options)], // expected: requirement is satisfied. false, ), @@ -470,7 +469,7 @@ mod tests { // constants vec![col_e], // requirement [exp(a) ASC, a+b ASC], - vec![(exp_a, options), (&a_plus_b, options)], + vec![(&exp_a, options), (&a_plus_b, options)], // expected: requirement is not satisfied. // TODO: If we know that exp function is 1-to-1 function. // we could have deduced that above requirement is satisfied. @@ -490,7 +489,7 @@ mod tests { // constants vec![col_e], // requirement [a ASC, d ASC, floor(a) ASC], - vec![(col_a, options), (col_d, options), (floor_a, options)], + vec![(col_a, options), (col_d, options), (&floor_a, options)], // expected: requirement is satisfied. true, ), @@ -508,7 +507,7 @@ mod tests { // constants vec![col_e], // requirement [a ASC, floor(a) ASC, a + b ASC], - vec![(col_a, options), (floor_a, options), (&a_plus_b, options)], + vec![(col_a, options), (&floor_a, options), (&a_plus_b, options)], // expected: requirement is not satisfied. false, ), @@ -529,7 +528,7 @@ mod tests { vec![ (col_a, options), (col_c, options), - (floor_a, options), + (&floor_a, options), (&a_plus_b, options), ], // expected: requirement is not satisfied. @@ -556,7 +555,7 @@ mod tests { (col_a, options), (col_b, options), (col_c, options), - (floor_a, options), + (&floor_a, options), ], // expected: requirement is satisfied. true, diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 681484fd6bff..d1e7625525ae 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -143,12 +143,11 @@ mod tests { }; use crate::equivalence::EquivalenceProperties; use crate::expressions::{col, BinaryExpr}; - use crate::udf::create_physical_expr; use crate::utils::tests::TestScalarUDF; + use crate::{PhysicalExprRef, ScalarFunctionExpr}; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{SortOptions, TimeUnit}; - use datafusion_common::DFSchema; use datafusion_expr::{Operator, ScalarUDF}; #[test] @@ -667,14 +666,13 @@ mod tests { Arc::clone(col_b), )) as Arc; - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let round_c = &create_physical_expr( - &test_fun, - &[Arc::clone(col_c)], + let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new())); + + let round_c = Arc::new(ScalarFunctionExpr::try_new( + test_fun, + vec![Arc::clone(col_c)], &schema, - &[], - &DFSchema::empty(), - )?; + )?) as PhysicalExprRef; let option_asc = SortOptions { descending: false, @@ -685,7 +683,7 @@ mod tests { (col_b, "b_new".to_string()), (col_a, "a_new".to_string()), (col_c, "c_new".to_string()), - (round_c, "round_c_res".to_string()), + (&round_c, "round_c_res".to_string()), ]; let proj_exprs = proj_exprs .into_iter() diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 4c55f4ddba93..11d6f54a7cc3 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -32,6 +32,7 @@ mod physical_expr; pub mod planner; mod scalar_function; pub mod udf { + #[allow(deprecated)] pub use crate::scalar_function::create_physical_expr; } pub mod utils; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 906ca9fd1093..e05de362bf14 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use crate::scalar_function; +use crate::ScalarFunctionExpr; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, PhysicalExpr, @@ -302,13 +302,11 @@ pub fn create_physical_expr( let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; - scalar_function::create_physical_expr( - Arc::clone(func).as_ref(), - &physical_args, + Ok(Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(func), + physical_args, input_schema, - args, - input_dfschema, - ) + )?)) } Expr::Between(Between { expr, diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 0ae4115de67a..936adbc098d6 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -34,6 +34,7 @@ use std::fmt::{self, Debug, Formatter}; use std::hash::Hash; use std::sync::Arc; +use crate::expressions::Literal; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; @@ -43,7 +44,9 @@ use datafusion_common::{internal_err, DFSchema, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; -use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF}; +use datafusion_expr::{ + expr_vec_fmt, ColumnarValue, Expr, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, +}; /// Physical expression of a scalar function #[derive(Eq, PartialEq, Hash)] @@ -83,6 +86,49 @@ impl ScalarFunctionExpr { } } + /// Create a new Scalar function + pub fn try_new( + fun: Arc, + args: Vec>, + schema: &Schema, + ) -> Result { + let name = fun.name().to_string(); + let arg_types = args + .iter() + .map(|e| e.data_type(schema)) + .collect::>>()?; + + // verify that input data types is consistent with function's `TypeSignature` + data_types_with_scalar_udf(&arg_types, &fun)?; + + let nullables = args + .iter() + .map(|e| e.nullable(schema)) + .collect::>>()?; + + let arguments = args + .iter() + .map(|e| { + e.as_any() + .downcast_ref::() + .map(|literal| literal.value()) + }) + .collect::>(); + let ret_args = ReturnTypeArgs { + arg_types: &arg_types, + scalar_arguments: &arguments, + nullables: &nullables, + }; + let (return_type, nullable) = fun.return_type_from_args(ret_args)?.into_parts(); + Ok(Self { + fun, + name, + args, + return_type, + nullable, + }) + } + /// Get the scalar function implementation pub fn fun(&self) -> &ScalarUDF { &self.fun @@ -218,6 +264,7 @@ impl PhysicalExpr for ScalarFunctionExpr { } /// Create a physical expression for the UDF. +#[deprecated(since = "45.0.0", note = "use ScalarFunctionExpr::new() instead")] pub fn create_physical_expr( fun: &ScalarUDF, input_phy_exprs: &[Arc], diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 77b10b41ccb3..654218531f1d 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -95,7 +95,7 @@ SELECT arrow_cast('1', 'Int16') query error SELECT arrow_cast('1') -query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string, got Literal\(Int64\(43\)\) +query error DataFusion error: Execution error: arrow_cast requires its second argument to be a non\-empty constant string SELECT arrow_cast('1', 43) query error Error unrecognized word: unknown diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index b05e86e5ea91..d671798b7d0f 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -151,19 +151,19 @@ query error DataFusion error: Execution error: named_struct requires an even num select named_struct('a', 1, 'b'); # error on even argument not a string literal #1 -query error DataFusion error: Execution error: named_struct even arguments must be string literals, got Int64\(1\) instead at position 0 +query error DataFusion error: Execution error: named_struct requires 0\-th \(0\-indexed\) field name as non\-empty constant string select named_struct(1, 'a'); # error on even argument not a string literal #2 -query error DataFusion error: Execution error: named_struct even arguments must be string literals, got Int64\(0\) instead at position 2 +query error DataFusion error: Execution error: named_struct requires 2\-th \(0\-indexed\) field name as non\-empty constant string select named_struct('corret', 1, 0, 'wrong'); # error on even argument not a string literal #3 -query error DataFusion error: Execution error: named_struct even arguments must be string literals, got values\.a instead at position 0 +query error DataFusion error: Execution error: named_struct requires 0\-th \(0\-indexed\) field name as non\-empty constant string select named_struct(values.a, 'a') from values; # error on even argument not a string literal #4 -query error DataFusion error: Execution error: named_struct even arguments must be string literals, got values\.c instead at position 0 +query error DataFusion error: Execution error: named_struct requires 0\-th \(0\-indexed\) field name as non\-empty constant string select named_struct(values.c, 'c') from values; # named_struct with mixed scalar and array values #1