From 63949c0a25d24aaaa07f110dbff8a9d14b0c9d49 Mon Sep 17 00:00:00 2001 From: jasonnnli Date: Tue, 14 Nov 2023 23:02:21 +0800 Subject: [PATCH] feat: refactor ReturnTypeFunction --- datafusion-examples/examples/complex_udf.rs | 12 ++--- datafusion/expr/src/expr_schema.rs | 33 ++++++++---- datafusion/expr/src/function.rs | 13 +++-- datafusion/expr/src/udaf.rs | 5 +- datafusion/expr/src/udwf.rs | 5 +- datafusion/expr/src/window_function.rs | 53 +++++++++++-------- datafusion/physical-expr/src/udf.rs | 12 ++--- datafusion/physical-expr/src/utils.rs | 13 +++++ datafusion/physical-plan/src/udaf.rs | 8 ++- datafusion/physical-plan/src/windows/mod.rs | 4 +- .../tests/cases/roundtrip_logical_plan.rs | 7 ++- 11 files changed, 104 insertions(+), 61 deletions(-) diff --git a/datafusion-examples/examples/complex_udf.rs b/datafusion-examples/examples/complex_udf.rs index 185524a284f48..6e285fb6c8dce 100644 --- a/datafusion-examples/examples/complex_udf.rs +++ b/datafusion-examples/examples/complex_udf.rs @@ -27,7 +27,7 @@ use datafusion::{ use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::ScalarValue; -use datafusion_expr::function::ReturnTypeFactory; +use datafusion_expr::function::{ConstantArg, ReturnTypeFactory}; use datafusion_expr::{ ColumnarValue, ScalarFunctionImplementation, ScalarUDF, Signature, }; @@ -87,11 +87,11 @@ async fn main() -> Result<()> { impl ReturnTypeFactory for ReturnType { fn infer( &self, - data_types: &[DataType], - literals: &[(usize, ScalarValue)], + input_types: &[DataType], + constant_args: &[ConstantArg], ) -> Result> { - assert_eq!(literals.len(), 1); - let (idx, val) = &literals[0]; + assert_eq!(constant_args.len(), 1); + let (idx, val) = &constant_args[0]; assert_eq!(idx, &2); let take_idx = match val { @@ -99,7 +99,7 @@ async fn main() -> Result<()> { _ => unreachable!(), }; - Ok(Arc::new(data_types[take_idx].clone())) + Ok(Arc::new(input_types[take_idx].clone())) } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index f15437c3624d8..81b8c943691d1 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -22,6 +22,7 @@ use crate::expr::{ TryCast, WindowFunction, }; use crate::field_util::GetFieldAccessSchema; +use crate::function::ConstantArg; use crate::type_coercion::binary::get_result_type; use crate::type_coercion::functions::data_types; use crate::{utils, LogicalPlan, Projection, Subquery}; @@ -87,17 +88,10 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - - let mut literals = vec![]; - args.iter().enumerate().for_each(|(i, arg)| { - if let Expr::Literal(scalar_value) = arg { - literals.push((i, scalar_value.clone())); - } - }); - + let constant_args = extract_constant_args(&args); Ok(fun .return_type - .infer(&data_types, &literals)? + .infer(&data_types, &constant_args)? .as_ref() .clone()) } @@ -126,7 +120,8 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - fun.return_type(&data_types) + let constant_args = extract_constant_args(&args); + fun.return_type(&data_types, &constant_args) } Expr::AggregateFunction(AggregateFunction { fun, args, .. }) => { let data_types = args @@ -140,7 +135,12 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) + let constant_args = extract_constant_args(&args); + Ok(fun + .return_type + .infer(&data_types, &constant_args)? + .as_ref() + .clone()) } Expr::Not(_) | Expr::IsNull(_) @@ -355,6 +355,17 @@ impl ExprSchemable for Expr { } } +/// Extract constant from input Expr +pub fn extract_constant_args(args: &[Expr]) -> Vec { + let mut constant_args = vec![]; + args.iter().enumerate().for_each(|(i, arg)| { + if let Expr::Literal(scalar_value) = arg { + constant_args.push((i, scalar_value.clone())); + } + }); + constant_args +} + /// return the schema [`Field`] for the type referenced by `get_indexed_field` fn field_for_index( expr: &Expr, diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 1181a80853d52..e2b68dc4d2996 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -36,12 +36,15 @@ use strum::IntoEnumIterator; pub type ScalarFunctionImplementation = Arc Result + Send + Sync>; +/// Constant argument, (arg index, constant value). +pub type ConstantArg = (usize, ScalarValue); + /// Factory that returns the functions's return type given the input argument types and constant arguments pub trait ReturnTypeFactory: Send + Sync { fn infer( &self, - data_types: &[DataType], - literals: &[(usize, ScalarValue)], + input_types: &[DataType], + constant_args: &[ConstantArg], ) -> Result>; } @@ -52,10 +55,10 @@ pub type ReturnTypeFunction = impl ReturnTypeFactory for ReturnTypeFunction { fn infer( &self, - data_types: &[DataType], - _literals: &[(usize, ScalarValue)], + input_types: &[DataType], + _constant_args: &[ConstantArg], ) -> Result> { - self(data_types) + self(input_types) } } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 84e238a1215b2..2bb330fd887fd 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,6 +17,7 @@ //! Udaf module contains functions and structs supporting user-defined aggregate functions. +use crate::function::ReturnTypeFactory; use crate::Expr; use crate::{ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, @@ -50,7 +51,7 @@ pub struct AggregateUDF { /// Signature (input arguments) pub signature: Signature, /// Return type - pub return_type: ReturnTypeFunction, + pub return_type: Arc, /// actual implementation pub accumulator: AccumulatorFactoryFunction, /// the accumulator's state's description as a function of the return type @@ -94,7 +95,7 @@ impl AggregateUDF { Self { name: name.to_owned(), signature: signature.clone(), - return_type: return_type.clone(), + return_type: Arc::new(return_type.clone()), accumulator: accumulator.clone(), state_type: state_type.clone(), } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index c0a2a8205a080..cf9163ad5a77a 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -22,6 +22,7 @@ use std::{ sync::Arc, }; +use crate::function::ReturnTypeFactory; use crate::{ Expr, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, }; @@ -39,7 +40,7 @@ pub struct WindowUDF { /// signature pub signature: Signature, /// Return type - pub return_type: ReturnTypeFunction, + pub return_type: Arc, /// Return the partition evaluator pub partition_evaluator_factory: PartitionEvaluatorFactory, } @@ -88,7 +89,7 @@ impl WindowUDF { Self { name: name.to_owned(), signature: signature.clone(), - return_type: return_type.clone(), + return_type: Arc::new(return_type.clone()), partition_evaluator_factory: partition_evaluator_factory.clone(), } } diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 463cceafeb6ea..226e001ca106c 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -21,6 +21,7 @@ //! see also use crate::aggregate_function::AggregateFunction; +use crate::function::ConstantArg; use crate::type_coercion::functions::data_types; use crate::utils; use crate::{AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF}; @@ -160,23 +161,31 @@ pub fn return_type( fun: &WindowFunction, input_expr_types: &[DataType], ) -> Result { - fun.return_type(input_expr_types) + fun.return_type(input_expr_types, &[]) } impl WindowFunction { /// Returns the datatype of the window function - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + pub fn return_type( + &self, + input_expr_types: &[DataType], + constant_args: &[ConstantArg], + ) -> Result { match self { WindowFunction::AggregateFunction(fun) => fun.return_type(input_expr_types), WindowFunction::BuiltInWindowFunction(fun) => { fun.return_type(input_expr_types) } - WindowFunction::AggregateUDF(fun) => { - Ok((*(fun.return_type)(input_expr_types)?).clone()) - } - WindowFunction::WindowUDF(fun) => { - Ok((*(fun.return_type)(input_expr_types)?).clone()) - } + WindowFunction::AggregateUDF(fun) => Ok(fun + .return_type + .infer(input_expr_types, constant_args)? + .as_ref() + .clone()), + WindowFunction::WindowUDF(fun) => Ok(fun + .return_type + .infer(input_expr_types, constant_args)? + .as_ref() + .clone()), } } } @@ -286,10 +295,10 @@ mod tests { #[test] fn test_count_return_type() -> Result<()> { let fun = find_df_window_func("count").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &vec![])?; assert_eq!(DataType::Int64, observed); - let observed = fun.return_type(&[DataType::UInt64])?; + let observed = fun.return_type(&[DataType::UInt64], &vec![])?; assert_eq!(DataType::Int64, observed); Ok(()) @@ -298,10 +307,10 @@ mod tests { #[test] fn test_first_value_return_type() -> Result<()> { let fun = find_df_window_func("first_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &vec![])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::UInt64])?; + let observed = fun.return_type(&[DataType::UInt64], &vec![])?; assert_eq!(DataType::UInt64, observed); Ok(()) @@ -310,10 +319,10 @@ mod tests { #[test] fn test_last_value_return_type() -> Result<()> { let fun = find_df_window_func("last_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &vec![])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64])?; + let observed = fun.return_type(&[DataType::Float64], &vec![])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -322,10 +331,10 @@ mod tests { #[test] fn test_lead_return_type() -> Result<()> { let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &vec![])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64])?; + let observed = fun.return_type(&[DataType::Float64], &vec![])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -334,10 +343,10 @@ mod tests { #[test] fn test_lag_return_type() -> Result<()> { let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &[])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64])?; + let observed = fun.return_type(&[DataType::Float64], &[])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -346,10 +355,10 @@ mod tests { #[test] fn test_nth_value_return_type() -> Result<()> { let fun = find_df_window_func("nth_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; + let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64], &[])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; + let observed = fun.return_type(&[DataType::Float64, DataType::UInt64], &[])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -358,7 +367,7 @@ mod tests { #[test] fn test_percent_rank_return_type() -> Result<()> { let fun = find_df_window_func("percent_rank").unwrap(); - let observed = fun.return_type(&[])?; + let observed = fun.return_type(&[], &[])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -367,7 +376,7 @@ mod tests { #[test] fn test_cume_dist_return_type() -> Result<()> { let fun = find_df_window_func("cume_dist").unwrap(); - let observed = fun.return_type(&[])?; + let observed = fun.return_type(&[], &[])?; assert_eq!(DataType::Float64, observed); Ok(()) diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index dd86faa03ed69..e55ddd0a9dcad 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -16,6 +16,7 @@ // under the License. //! UDF support +use crate::utils::extract_constant_args; use crate::{PhysicalExpr, ScalarFunctionExpr}; use arrow::datatypes::Schema; use datafusion_common::Result; @@ -33,21 +34,14 @@ pub fn create_physical_expr( .iter() .map(|e| e.data_type(input_schema)) .collect::>>()?; - - let mut literals = vec![]; - input_phy_exprs.iter().enumerate().for_each(|(i, expr)| { - if let Some(literal) = expr.as_any().downcast_ref::() - { - literals.push((i, literal.value().clone())); - } - }); + let constant_args = extract_constant_args(input_phy_exprs); Ok(Arc::new(ScalarFunctionExpr::new( &fun.name, fun.fun.clone(), input_phy_exprs.to_vec(), fun.return_type - .infer(&input_exprs_types, &literals)? + .infer(&input_exprs_types, &constant_args)? .as_ref(), None, ))) diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index 2f4ee89463a85..760129266df03 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -31,6 +31,7 @@ use datafusion_common::tree_node::{ use datafusion_common::Result; use datafusion_expr::Operator; +use datafusion_expr::function::ConstantArg; use itertools::Itertools; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableGraph; @@ -359,6 +360,18 @@ pub fn merge_vectors( .collect() } +/// Extract constant from input PhysicalExprs +pub fn extract_constant_args(args: &[Arc]) -> Vec { + let mut constant_args = vec![]; + args.iter().enumerate().for_each(|(i, expr)| { + if let Some(literal) = expr.as_any().downcast_ref::() + { + constant_args.push((i, literal.value().clone())); + } + }); + constant_args +} + #[cfg(test)] mod tests { use std::fmt::{Display, Formatter}; diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 7cc3cc7d59fed..bee9f6418ac7c 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -32,6 +32,7 @@ pub use datafusion_expr::AggregateUDF; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr::aggregate::utils::down_cast_any_ref; +use datafusion_physical_expr::utils::extract_constant_args; use std::sync::Arc; /// Creates a physical expression of the UDAF, that includes all necessary type coercion. @@ -46,11 +47,16 @@ pub fn create_aggregate_expr( .iter() .map(|arg| arg.data_type(input_schema)) .collect::>>()?; + let constant_args = extract_constant_args(input_phy_exprs); Ok(Arc::new(AggregateFunctionExpr { fun: fun.clone(), args: input_phy_exprs.to_vec(), - data_type: (fun.return_type)(&input_exprs_types)?.as_ref().clone(), + data_type: fun + .return_type + .infer(&input_exprs_types, &constant_args)? + .as_ref() + .clone(), name: name.into(), })) } diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index b6ed6e482ff50..215a03c16ef0b 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -48,6 +48,7 @@ mod bounded_window_agg_exec; mod window_agg_exec; pub use bounded_window_agg_exec::BoundedWindowAggExec; +use datafusion_physical_expr::utils::extract_constant_args; pub use window_agg_exec::WindowAggExec; pub use datafusion_physical_expr::window::{ @@ -253,9 +254,10 @@ fn create_udwf_window_expr( .iter() .map(|arg| arg.data_type(input_schema)) .collect::>()?; + let constant_args = extract_constant_args(args); // figure out the output type - let data_type = (fun.return_type)(&input_types)?; + let data_type = fun.return_type.infer(&input_types, &constant_args)?; Ok(Arc::new(WindowUDFExpr { fun: Arc::clone(fun), args: args.to_vec(), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 97c553dc04e64..8e614c6e058fa 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -45,8 +45,9 @@ use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::{Sqrt, Substr}, - Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, WindowUDF, + Expr, LogicalPlan, Operator, PartitionEvaluator, ReturnTypeFunction, Signature, + TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + WindowUDF, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -1529,6 +1530,8 @@ fn roundtrip_window() { Ok(Arc::new(arg_types[0].clone())) } + let return_type: ReturnTypeFunction = Arc::new(return_type); + fn make_partition_evaluator() -> Result> { Ok(Box::new(DummyWindow {})) }