Skip to content

Commit

Permalink
feat: refactor ReturnTypeFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonnnli committed Nov 14, 2023
1 parent 681279a commit 63949c0
Show file tree
Hide file tree
Showing 11 changed files with 104 additions and 61 deletions.
12 changes: 6 additions & 6 deletions datafusion-examples/examples/complex_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use datafusion::{
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::ScalarValue;
use datafusion_expr::function::ReturnTypeFactory;
use datafusion_expr::function::{ConstantArg, ReturnTypeFactory};
use datafusion_expr::{
ColumnarValue, ScalarFunctionImplementation, ScalarUDF, Signature,
};
Expand Down Expand Up @@ -87,19 +87,19 @@ async fn main() -> Result<()> {
impl ReturnTypeFactory for ReturnType {
fn infer(
&self,
data_types: &[DataType],
literals: &[(usize, ScalarValue)],
input_types: &[DataType],
constant_args: &[ConstantArg],
) -> Result<Arc<DataType>> {
assert_eq!(literals.len(), 1);
let (idx, val) = &literals[0];
assert_eq!(constant_args.len(), 1);
let (idx, val) = &constant_args[0];
assert_eq!(idx, &2);

let take_idx = match val {
ScalarValue::Int64(Some(v)) if v < &2 => *v as usize,
_ => unreachable!(),
};

Ok(Arc::new(data_types[take_idx].clone()))
Ok(Arc::new(input_types[take_idx].clone()))
}
}

Expand Down
33 changes: 22 additions & 11 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::expr::{
TryCast, WindowFunction,
};
use crate::field_util::GetFieldAccessSchema;
use crate::function::ConstantArg;
use crate::type_coercion::binary::get_result_type;
use crate::type_coercion::functions::data_types;
use crate::{utils, LogicalPlan, Projection, Subquery};
Expand Down Expand Up @@ -87,17 +88,10 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;

let mut literals = vec![];
args.iter().enumerate().for_each(|(i, arg)| {
if let Expr::Literal(scalar_value) = arg {
literals.push((i, scalar_value.clone()));
}
});

let constant_args = extract_constant_args(&args);
Ok(fun
.return_type
.infer(&data_types, &literals)?
.infer(&data_types, &constant_args)?
.as_ref()
.clone())
}
Expand Down Expand Up @@ -126,7 +120,8 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
fun.return_type(&data_types)
let constant_args = extract_constant_args(&args);
fun.return_type(&data_types, &constant_args)
}
Expr::AggregateFunction(AggregateFunction { fun, args, .. }) => {
let data_types = args
Expand All @@ -140,7 +135,12 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
Ok((fun.return_type)(&data_types)?.as_ref().clone())
let constant_args = extract_constant_args(&args);
Ok(fun
.return_type
.infer(&data_types, &constant_args)?
.as_ref()
.clone())
}
Expr::Not(_)
| Expr::IsNull(_)
Expand Down Expand Up @@ -355,6 +355,17 @@ impl ExprSchemable for Expr {
}
}

/// Extract constant from input Expr
pub fn extract_constant_args(args: &[Expr]) -> Vec<ConstantArg> {
let mut constant_args = vec![];
args.iter().enumerate().for_each(|(i, arg)| {
if let Expr::Literal(scalar_value) = arg {
constant_args.push((i, scalar_value.clone()));
}
});
constant_args
}

/// return the schema [`Field`] for the type referenced by `get_indexed_field`
fn field_for_index<S: ExprSchema>(
expr: &Expr,
Expand Down
13 changes: 8 additions & 5 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,15 @@ use strum::IntoEnumIterator;
pub type ScalarFunctionImplementation =
Arc<dyn Fn(&[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>;

/// Constant argument, (arg index, constant value).
pub type ConstantArg = (usize, ScalarValue);

/// Factory that returns the functions's return type given the input argument types and constant arguments
pub trait ReturnTypeFactory: Send + Sync {
fn infer(
&self,
data_types: &[DataType],
literals: &[(usize, ScalarValue)],
input_types: &[DataType],
constant_args: &[ConstantArg],
) -> Result<Arc<DataType>>;
}

Expand All @@ -52,10 +55,10 @@ pub type ReturnTypeFunction =
impl ReturnTypeFactory for ReturnTypeFunction {
fn infer(
&self,
data_types: &[DataType],
_literals: &[(usize, ScalarValue)],
input_types: &[DataType],
_constant_args: &[ConstantArg],
) -> Result<Arc<DataType>> {
self(data_types)
self(input_types)
}
}

Expand Down
5 changes: 3 additions & 2 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Udaf module contains functions and structs supporting user-defined aggregate functions.
use crate::function::ReturnTypeFactory;
use crate::Expr;
use crate::{
AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction,
Expand Down Expand Up @@ -50,7 +51,7 @@ pub struct AggregateUDF {
/// Signature (input arguments)
pub signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,
pub return_type: Arc<dyn ReturnTypeFactory>,
/// actual implementation
pub accumulator: AccumulatorFactoryFunction,
/// the accumulator's state's description as a function of the return type
Expand Down Expand Up @@ -94,7 +95,7 @@ impl AggregateUDF {
Self {
name: name.to_owned(),
signature: signature.clone(),
return_type: return_type.clone(),
return_type: Arc::new(return_type.clone()),
accumulator: accumulator.clone(),
state_type: state_type.clone(),
}
Expand Down
5 changes: 3 additions & 2 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use std::{
sync::Arc,
};

use crate::function::ReturnTypeFactory;
use crate::{
Expr, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame,
};
Expand All @@ -39,7 +40,7 @@ pub struct WindowUDF {
/// signature
pub signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,
pub return_type: Arc<dyn ReturnTypeFactory>,
/// Return the partition evaluator
pub partition_evaluator_factory: PartitionEvaluatorFactory,
}
Expand Down Expand Up @@ -88,7 +89,7 @@ impl WindowUDF {
Self {
name: name.to_owned(),
signature: signature.clone(),
return_type: return_type.clone(),
return_type: Arc::new(return_type.clone()),
partition_evaluator_factory: partition_evaluator_factory.clone(),
}
}
Expand Down
53 changes: 31 additions & 22 deletions datafusion/expr/src/window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
//! see also <https://www.postgresql.org/docs/current/functions-window.html>
use crate::aggregate_function::AggregateFunction;
use crate::function::ConstantArg;
use crate::type_coercion::functions::data_types;
use crate::utils;
use crate::{AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF};
Expand Down Expand Up @@ -160,23 +161,31 @@ pub fn return_type(
fun: &WindowFunction,
input_expr_types: &[DataType],
) -> Result<DataType> {
fun.return_type(input_expr_types)
fun.return_type(input_expr_types, &[])
}

impl WindowFunction {
/// Returns the datatype of the window function
pub fn return_type(&self, input_expr_types: &[DataType]) -> Result<DataType> {
pub fn return_type(
&self,
input_expr_types: &[DataType],
constant_args: &[ConstantArg],
) -> Result<DataType> {
match self {
WindowFunction::AggregateFunction(fun) => fun.return_type(input_expr_types),
WindowFunction::BuiltInWindowFunction(fun) => {
fun.return_type(input_expr_types)
}
WindowFunction::AggregateUDF(fun) => {
Ok((*(fun.return_type)(input_expr_types)?).clone())
}
WindowFunction::WindowUDF(fun) => {
Ok((*(fun.return_type)(input_expr_types)?).clone())
}
WindowFunction::AggregateUDF(fun) => Ok(fun
.return_type
.infer(input_expr_types, constant_args)?
.as_ref()
.clone()),
WindowFunction::WindowUDF(fun) => Ok(fun
.return_type
.infer(input_expr_types, constant_args)?
.as_ref()
.clone()),
}
}
}
Expand Down Expand Up @@ -286,10 +295,10 @@ mod tests {
#[test]
fn test_count_return_type() -> Result<()> {
let fun = find_df_window_func("count").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
let observed = fun.return_type(&[DataType::Utf8], &vec![])?;
assert_eq!(DataType::Int64, observed);

let observed = fun.return_type(&[DataType::UInt64])?;
let observed = fun.return_type(&[DataType::UInt64], &vec![])?;
assert_eq!(DataType::Int64, observed);

Ok(())
Expand All @@ -298,10 +307,10 @@ mod tests {
#[test]
fn test_first_value_return_type() -> Result<()> {
let fun = find_df_window_func("first_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
let observed = fun.return_type(&[DataType::Utf8], &vec![])?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::UInt64])?;
let observed = fun.return_type(&[DataType::UInt64], &vec![])?;
assert_eq!(DataType::UInt64, observed);

Ok(())
Expand All @@ -310,10 +319,10 @@ mod tests {
#[test]
fn test_last_value_return_type() -> Result<()> {
let fun = find_df_window_func("last_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
let observed = fun.return_type(&[DataType::Utf8], &vec![])?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::Float64])?;
let observed = fun.return_type(&[DataType::Float64], &vec![])?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -322,10 +331,10 @@ mod tests {
#[test]
fn test_lead_return_type() -> Result<()> {
let fun = find_df_window_func("lead").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
let observed = fun.return_type(&[DataType::Utf8], &vec![])?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::Float64])?;
let observed = fun.return_type(&[DataType::Float64], &vec![])?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -334,10 +343,10 @@ mod tests {
#[test]
fn test_lag_return_type() -> Result<()> {
let fun = find_df_window_func("lag").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
let observed = fun.return_type(&[DataType::Utf8], &[])?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::Float64])?;
let observed = fun.return_type(&[DataType::Float64], &[])?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -346,10 +355,10 @@ mod tests {
#[test]
fn test_nth_value_return_type() -> Result<()> {
let fun = find_df_window_func("nth_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?;
let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64], &[])?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?;
let observed = fun.return_type(&[DataType::Float64, DataType::UInt64], &[])?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -358,7 +367,7 @@ mod tests {
#[test]
fn test_percent_rank_return_type() -> Result<()> {
let fun = find_df_window_func("percent_rank").unwrap();
let observed = fun.return_type(&[])?;
let observed = fun.return_type(&[], &[])?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -367,7 +376,7 @@ mod tests {
#[test]
fn test_cume_dist_return_type() -> Result<()> {
let fun = find_df_window_func("cume_dist").unwrap();
let observed = fun.return_type(&[])?;
let observed = fun.return_type(&[], &[])?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand Down
12 changes: 3 additions & 9 deletions datafusion/physical-expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

//! UDF support
use crate::utils::extract_constant_args;
use crate::{PhysicalExpr, ScalarFunctionExpr};
use arrow::datatypes::Schema;
use datafusion_common::Result;
Expand All @@ -33,21 +34,14 @@ pub fn create_physical_expr(
.iter()
.map(|e| e.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;

let mut literals = vec![];
input_phy_exprs.iter().enumerate().for_each(|(i, expr)| {
if let Some(literal) = expr.as_any().downcast_ref::<crate::expressions::Literal>()
{
literals.push((i, literal.value().clone()));
}
});
let constant_args = extract_constant_args(input_phy_exprs);

Ok(Arc::new(ScalarFunctionExpr::new(
&fun.name,
fun.fun.clone(),
input_phy_exprs.to_vec(),
fun.return_type
.infer(&input_exprs_types, &literals)?
.infer(&input_exprs_types, &constant_args)?
.as_ref(),
None,
)))
Expand Down
13 changes: 13 additions & 0 deletions datafusion/physical-expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use datafusion_common::tree_node::{
use datafusion_common::Result;
use datafusion_expr::Operator;

use datafusion_expr::function::ConstantArg;
use itertools::Itertools;
use petgraph::graph::NodeIndex;
use petgraph::stable_graph::StableGraph;
Expand Down Expand Up @@ -359,6 +360,18 @@ pub fn merge_vectors(
.collect()
}

/// Extract constant from input PhysicalExprs
pub fn extract_constant_args(args: &[Arc<dyn PhysicalExpr>]) -> Vec<ConstantArg> {
let mut constant_args = vec![];
args.iter().enumerate().for_each(|(i, expr)| {
if let Some(literal) = expr.as_any().downcast_ref::<crate::expressions::Literal>()
{
constant_args.push((i, literal.value().clone()));
}
});
constant_args
}

#[cfg(test)]
mod tests {
use std::fmt::{Display, Formatter};
Expand Down
Loading

0 comments on commit 63949c0

Please sign in to comment.