-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add two new methods in ScalarFunction return_type_from_args
and is_nullable_from_args_nullable
#14094
base: main
Are you sure you want to change the base?
Add two new methods in ScalarFunction return_type_from_args
and is_nullable_from_args_nullable
#14094
Changes from all commits
6b00b9a
b079be3
8c9ee8c
6df7476
fe7f6a5
4da4c71
de4b484
02a64ce
f26ce70
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -182,6 +182,7 @@ impl ScalarUDF { | |
/// | ||
/// | ||
/// See [`ScalarUDFImpl::return_type_from_exprs`] for more details. | ||
#[allow(deprecated)] | ||
pub fn return_type_from_exprs( | ||
&self, | ||
args: &[Expr], | ||
|
@@ -192,6 +193,10 @@ impl ScalarUDF { | |
self.inner.return_type_from_exprs(args, schema, arg_types) | ||
} | ||
|
||
/// Returns the [`DataType`] produced by this function for the given `args`. | ||
/// | ||
/// Forwards `args` unchanged to the wrapped implementation. | ||
/// | ||
/// See [`ScalarUDFImpl::return_type_from_args`] for more details. | ||
pub fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<DataType> { | ||
self.inner.return_type_from_args(args) | ||
} | ||
|
||
/// Do the function rewrite | ||
/// | ||
/// See [`ScalarUDFImpl::simplify`] for more details. | ||
|
@@ -209,10 +214,15 @@ impl ScalarUDF { | |
self.inner.invoke(args) | ||
} | ||
|
||
/// Forwards `args` and `schema` to the wrapped implementation's `is_nullable`. | ||
/// | ||
/// The `allow(deprecated)` is needed because the trait method being called is | ||
/// marked `#[deprecated]` in favor of `is_nullable_from_args_nullable`. | ||
#[allow(deprecated)] | ||
pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { | ||
self.inner.is_nullable(args, schema) | ||
} | ||
|
||
/// Forwards the pre-computed per-argument nullable flags to the wrapped | ||
/// implementation and returns its answer. | ||
/// | ||
/// See [`ScalarUDFImpl::is_nullable_from_args_nullable`] for more details. | ||
pub fn is_nullable_from_args_nullable(&self, args_nullables: &[bool]) -> bool { | ||
self.inner.is_nullable_from_args_nullable(args_nullables) | ||
} | ||
|
||
pub fn invoke_batch( | ||
&self, | ||
args: &[ColumnarValue], | ||
|
@@ -342,6 +352,14 @@ pub struct ScalarFunctionArgs<'a> { | |
pub return_type: &'a DataType, | ||
} | ||
|
||
#[derive(Debug)] | ||
pub struct ReturnTypeArgs<'a> { | ||
/// The data types of the arguments to the function | ||
pub arg_types: &'a [DataType], | ||
/// The `Utf8` (string) arguments to the function; an argument whose expression is not `Utf8` is represented by an empty string | ||
pub arguments: &'a [String], | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there a better name for this field? 🤔 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would it be possible to unify the argument handling so that both return type and nullability are handled the same way? I wonder if it would somehow be possible to add the input nullable information here too 🤔 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am also not sure about only supporting string args; that is likely a regression in behavior for some users (for example, maybe they look for constant integers as well). |
||
} | ||
|
||
/// Trait for implementing user defined scalar functions. | ||
/// | ||
/// This trait exposes the full API for implementing user defined functions and | ||
|
@@ -481,6 +499,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { | |
/// This function must consistently return the same type for the same | ||
/// logical input even if the input is simplified (e.g. it must return the same | ||
/// value for `('foo' | 'bar')` as it does for ('foobar'). | ||
#[deprecated(since = "45.0.0", note = "Use `return_type_from_args` instead")] | ||
fn return_type_from_exprs( | ||
&self, | ||
_args: &[Expr], | ||
|
@@ -490,10 +509,49 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { | |
self.return_type(arg_types) | ||
} | ||
|
||
/// What [`DataType`] will be returned by this function, given the | ||
/// arguments? | ||
/// | ||
/// Note most UDFs should implement [`Self::return_type`] and not this | ||
/// function. The output type for most functions only depends on the types | ||
/// of their inputs (e.g. `sqrt(f32)` is always `f32`). | ||
/// | ||
/// By default, this function calls [`Self::return_type`] with the | ||
/// types of each argument (`args.arg_types`). | ||
/// | ||
/// This method can be overridden for functions that return different | ||
/// *types* based on the *values* of their arguments. | ||
/// | ||
/// For example, the following two function calls get the same argument | ||
/// types (something and a `Utf8` string) but return different types based | ||
/// on the value of the second argument: | ||
/// | ||
/// * `arrow_cast(x, 'Int16')` --> `Int16` | ||
/// * `arrow_cast(x, 'Float32')` --> `Float32` | ||
/// | ||
/// # Notes: | ||
/// | ||
/// This function must consistently return the same type for the same | ||
/// logical input even if the input is simplified (e.g. it must return the same | ||
/// value for `('foo' | 'bar')` as it does for `('foobar')`). | ||
fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<DataType> { | ||
self.return_type(args.arg_types) | ||
} | ||
|
||
/// Default implementation: ignores `_args` and `_schema` and always | ||
/// returns `true`. | ||
/// | ||
/// Deprecated since 45.0.0; use `is_nullable_from_args_nullable` instead. | ||
#[deprecated( | ||
since = "45.0.0", | ||
note = "Use `is_nullable_from_args_nullable` instead" | ||
)] | ||
fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { | ||
true | ||
} | ||
|
||
/// `is_nullable` computed from pre-computed nullable flags. | ||
/// | ||
/// It has fewer dependencies on the input arguments: it receives only a | ||
/// slice of `bool` flags rather than expressions and a schema. | ||
/// | ||
/// The default implementation ignores the flags and always returns `true`. | ||
fn is_nullable_from_args_nullable(&self, _args_nullables: &[bool]) -> bool { | ||
true | ||
} | ||
|
||
/// Invoke the function on `args`, returning the appropriate result | ||
/// | ||
/// Note: This method is deprecated and will be removed in future releases. | ||
|
@@ -787,6 +845,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { | |
&self.aliases | ||
} | ||
|
||
#[allow(deprecated)] | ||
fn return_type_from_exprs( | ||
&self, | ||
args: &[Expr], | ||
|
@@ -796,6 +855,10 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { | |
self.inner.return_type_from_exprs(args, schema, arg_types) | ||
} | ||
|
||
// Forward to the wrapped implementation so this wrapper reports the same | ||
// return type as `self.inner` for the same `args`. | ||
fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<DataType> { | ||
self.inner.return_type_from_args(args) | ||
} | ||
|
||
fn invoke_batch( | ||
&self, | ||
args: &[ColumnarValue], | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,17 +18,18 @@ | |
//! [`ArrowCastFunc`]: Implementation of the `arrow_cast` | ||
|
||
use arrow::datatypes::DataType; | ||
use datafusion_common::DataFusionError; | ||
use datafusion_common::{ | ||
arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, | ||
ExprSchema, Result, ScalarValue, | ||
arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, Result, | ||
ScalarValue, | ||
}; | ||
use std::any::Any; | ||
use std::sync::OnceLock; | ||
|
||
use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; | ||
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; | ||
use datafusion_expr::{ | ||
ColumnarValue, Documentation, Expr, ExprSchemable, ScalarUDFImpl, Signature, | ||
ColumnarValue, Documentation, Expr, ReturnTypeArgs, ScalarUDFImpl, Signature, | ||
Volatility, | ||
}; | ||
|
||
|
@@ -86,22 +87,36 @@ impl ScalarUDFImpl for ArrowCastFunc { | |
} | ||
|
||
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If this change looks good, we can deprecate this too. |
||
// should be using return_type_from_exprs and not calling the default | ||
// implementation | ||
internal_err!("arrow_cast should return type from exprs") | ||
internal_err!("return_type_from_args should be called instead") | ||
} | ||
|
||
fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { | ||
args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true)) | ||
// Returns `true` if any of the supplied argument flags is `true`; | ||
// an empty slice yields `false`. | ||
fn is_nullable_from_args_nullable(&self, args_nullables: &[bool]) -> bool { | ||
args_nullables.iter().any(|&nullable| nullable) | ||
} | ||
|
||
fn return_type_from_exprs( | ||
&self, | ||
args: &[Expr], | ||
_schema: &dyn ExprSchema, | ||
_arg_types: &[DataType], | ||
) -> Result<DataType> { | ||
data_type_from_args(args) | ||
// Determines the return type by parsing the second entry of | ||
// `args.arguments` into a `DataType`. | ||
// | ||
// Returns a plan error when the number of arguments is not exactly 2, when | ||
// the second argument is an empty string, or when the string fails to parse | ||
// with an Arrow `ParseError`; any other Arrow error is wrapped as-is. | ||
fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<DataType> { | ||
if args.arguments.len() != 2 { | ||
return plan_err!( | ||
"{} needs 2 arguments, {} provided", | ||
self.name(), | ||
args.arguments.len() | ||
); | ||
} | ||
|
||
// Index 1 is in bounds: the length was checked to be exactly 2 above. | ||
let val = &args.arguments[1]; | ||
// An empty string is treated as "not a constant string" and rejected. | ||
if val.is_empty() { | ||
return plan_err!( | ||
"{} requires its second argument to be a constant string", | ||
self.name() | ||
); | ||
}; | ||
|
||
val.parse().map_err(|e| match e { | ||
// If the data type cannot be parsed, return a Plan error to signal an | ||
// error in the input rather than a more general ArrowError | ||
arrow::error::ArrowError::ParseError(e) => plan_datafusion_err!("{e}"), | ||
e => arrow_datafusion_err!(e), | ||
}) | ||
} | ||
|
||
fn invoke_batch( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove the `Expr` dependency