From 4c10ce7bcdbde22dc6802d39ac59dc805bdc2cc9 Mon Sep 17 00:00:00 2001 From: jasonnnli Date: Sun, 8 Oct 2023 11:35:10 +0800 Subject: [PATCH] refactor: udf return_type --- datafusion-examples/examples/complex_udf.rs | 140 ++++++++++++++++++++ datafusion/expr/src/expr_schema.rs | 14 +- datafusion/expr/src/function.rs | 21 ++- datafusion/expr/src/udf.rs | 5 +- datafusion/physical-expr/src/udf.rs | 12 +- 5 files changed, 187 insertions(+), 5 deletions(-) create mode 100644 datafusion-examples/examples/complex_udf.rs diff --git a/datafusion-examples/examples/complex_udf.rs b/datafusion-examples/examples/complex_udf.rs new file mode 100644 index 000000000000..185524a284f4 --- /dev/null +++ b/datafusion-examples/examples/complex_udf.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{ + arrow::{ + array::{Float32Array, Float64Array}, + datatypes::DataType, + record_batch::RecordBatch, + }, + logical_expr::Volatility, +}; + +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::ScalarValue; +use datafusion_expr::function::ReturnTypeFactory; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionImplementation, ScalarUDF, Signature, +}; +use std::sync::Arc; + +// create local execution context with an in-memory table +fn create_context() -> Result { + use datafusion::arrow::datatypes::{Field, Schema}; + // define a schema. + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float64, false), + ])); + + // define data. + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1, 6.1])), + Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0])), + ], + )?; + + // declare a new context. In spark API, this corresponds to a new spark SQLsession + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + ctx.register_batch("t", batch)?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + const UDF_NAME: &str = "take"; + + let ctx = create_context()?; + + // Syntax: + // `take(float32_expr, float64_expr, index)` + // If index eq 0, return float32_expr, which DataType is DataType::Float32; + // If index eq 1, return float64_expr, which DataType is DataType::Float64; + // Else return Err. + let fun: ScalarFunctionImplementation = Arc::new(move |args: &[ColumnarValue]| { + let take_idx = match &args[2] { + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize, + _ => unreachable!(), + }; + match &args[take_idx] { + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(array.clone())), + ColumnarValue::Scalar(_) => unimplemented!(), + } + }); + + // Implement a ReturnTypeFactory. + struct ReturnType; + + impl ReturnTypeFactory for ReturnType { + fn infer( + &self, + data_types: &[DataType], + literals: &[(usize, ScalarValue)], + ) -> Result> { + assert_eq!(literals.len(), 1); + let (idx, val) = &literals[0]; + assert_eq!(idx, &2); + + let take_idx = match val { + ScalarValue::Int64(Some(v)) if v < &2 => *v as usize, + _ => unreachable!(), + }; + + Ok(Arc::new(data_types[take_idx].clone())) + } + } + + let signature = Signature::exact( + vec![DataType::Float32, DataType::Float64, DataType::Int64], + Volatility::Immutable, + ); + + let udf = ScalarUDF { + name: UDF_NAME.to_string(), + signature, + return_type: Arc::new(ReturnType {}), + fun, + }; + + ctx.register_udf(udf); + + // SELECT take(a, b, 0) AS take0, take(a, b, 1) AS take1 FROM t; + let df = ctx.table("t").await?; + let take = df.registry().udf(UDF_NAME)?; + let expr0 = take + .call(vec![col("a"), col("b"), lit(0_i64)]) + .alias("take0"); + let expr1 = take + .call(vec![col("a"), col("b"), lit(1_i64)]) + .alias("take1"); + + let df = df.select(vec![expr0, expr1])?; + let schema = df.schema(); + + // Check output schema + assert_eq!(schema.field(0).data_type(), &DataType::Float32); + assert_eq!(schema.field(1).data_type(), &DataType::Float64); + + df.show().await?; + + Ok(()) +} diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 5881feece1fc..fe0f49d65e85 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -87,7 +87,19 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) + + let mut literals = vec![]; + args.iter().enumerate().for_each(|(i, arg)| { + if let Expr::Literal(scalar_value) = arg { + literals.push((i, scalar_value.clone())); + } + }); + + Ok(fun + .return_type + .infer(&data_types, &literals)? + .as_ref() + .clone()) } Expr::ScalarFunction(ScalarFunction { fun, args }) => { let arg_data_types = args diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 3e30a5574be0..1181a80853d5 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -21,7 +21,7 @@ use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature}; use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue}; use arrow::datatypes::DataType; use datafusion_common::utils::datafusion_strsim; -use datafusion_common::Result; +use datafusion_common::{Result, ScalarValue}; use std::sync::Arc; use strum::IntoEnumIterator; @@ -36,10 +36,29 @@ use strum::IntoEnumIterator; pub type ScalarFunctionImplementation = Arc Result + Send + Sync>; +/// Factory that returns the functions's return type given the input argument types and constant arguments +pub trait ReturnTypeFactory: Send + Sync { + fn infer( + &self, + data_types: &[DataType], + literals: &[(usize, ScalarValue)], + ) -> Result>; +} + /// Factory that returns the functions's return type given the input argument types pub type ReturnTypeFunction = Arc Result> + Send + Sync>; +impl ReturnTypeFactory for ReturnTypeFunction { + fn infer( + &self, + data_types: &[DataType], + _literals: &[(usize, ScalarValue)], + ) -> Result> { + self(data_types) + } +} + /// Factory that returns an accumulator for the given aggregate, given /// its return datatype. pub type AccumulatorFactoryFunction = diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index be6c90aa5985..5edbe559177b 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,6 +17,7 @@ //! Udf module contains foundational types that are used to represent UDFs in DataFusion. +use crate::function::ReturnTypeFactory; use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature}; use std::fmt; use std::fmt::Debug; @@ -31,7 +32,7 @@ pub struct ScalarUDF { /// signature pub signature: Signature, /// Return type - pub return_type: ReturnTypeFunction, + pub return_type: Arc, /// actual implementation /// /// The fn param is the wrapped function but be aware that the function will @@ -79,7 +80,7 @@ impl ScalarUDF { Self { name: name.to_owned(), signature: signature.clone(), - return_type: return_type.clone(), + return_type: Arc::new(return_type.clone()), fun: fun.clone(), } } diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index af1e77cbf566..dd86faa03ed6 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -34,11 +34,21 @@ pub fn create_physical_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; + let mut literals = vec![]; + input_phy_exprs.iter().enumerate().for_each(|(i, expr)| { + if let Some(literal) = expr.as_any().downcast_ref::() + { + literals.push((i, literal.value().clone())); + } + }); + Ok(Arc::new(ScalarFunctionExpr::new( &fun.name, fun.fun.clone(), input_phy_exprs.to_vec(), - (fun.return_type)(&input_exprs_types)?.as_ref(), + fun.return_type + .infer(&input_exprs_types, &literals)? + .as_ref(), None, ))) }