refactor: udf return_type
jasonnnli committed Nov 15, 2023
1 parent abb2ae7 commit 4c10ce7
Showing 5 changed files with 187 additions and 5 deletions.
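In short: `ScalarUDF::return_type` changes from the closure-based `ReturnTypeFunction` to an `Arc<dyn ReturnTypeFactory>` trait object, so a UDF's return type can be inferred from its constant (literal) arguments as well as its argument types. A condensed before/after sketch of the shape of the change (the full definitions are in the `datafusion/expr/src/function.rs` diff below):

// Before: the return type could depend only on the argument types.
pub type ReturnTypeFunction =
    Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>;

// After: infer() also receives (argument index, value) pairs for the
// literal arguments, so constant inputs can steer the return type.
pub trait ReturnTypeFactory: Send + Sync {
    fn infer(
        &self,
        data_types: &[DataType],
        literals: &[(usize, ScalarValue)],
    ) -> Result<Arc<DataType>>;
}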
140 changes: 140 additions & 0 deletions datafusion-examples/examples/complex_udf.rs
@@ -0,0 +1,140 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion::{
    arrow::{
        array::{Float32Array, Float64Array},
        datatypes::DataType,
        record_batch::RecordBatch,
    },
    logical_expr::Volatility,
};

use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::ScalarValue;
use datafusion_expr::function::ReturnTypeFactory;
use datafusion_expr::{
    ColumnarValue, ScalarFunctionImplementation, ScalarUDF, Signature,
};
use std::sync::Arc;

// create local execution context with an in-memory table
fn create_context() -> Result<SessionContext> {
    use datafusion::arrow::datatypes::{Field, Schema};
    // define a schema.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float32, false),
        Field::new("b", DataType::Float64, false),
    ]));

    // define data.
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1, 6.1])),
            Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0])),
        ],
    )?;

    // declare a new context. In the Spark API, this corresponds to a new Spark SQL session.
    let ctx = SessionContext::new();

    // declare a table in memory. In the Spark API, this corresponds to createDataFrame(...).
    ctx.register_batch("t", batch)?;
    Ok(ctx)
}

#[tokio::main]
async fn main() -> Result<()> {
    const UDF_NAME: &str = "take";

    let ctx = create_context()?;

    // Syntax:
    //   `take(float32_expr, float64_expr, index)`
    // If index is 0, return float32_expr, whose DataType is DataType::Float32;
    // if index is 1, return float64_expr, whose DataType is DataType::Float64;
    // otherwise return an error.
    let fun: ScalarFunctionImplementation = Arc::new(move |args: &[ColumnarValue]| {
        let take_idx = match &args[2] {
            ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize,
            _ => unreachable!(),
        };
        match &args[take_idx] {
            ColumnarValue::Array(array) => Ok(ColumnarValue::Array(array.clone())),
            ColumnarValue::Scalar(_) => unimplemented!(),
        }
    });

    // Implement a ReturnTypeFactory.
    struct ReturnType;

    impl ReturnTypeFactory for ReturnType {
        fn infer(
            &self,
            data_types: &[DataType],
            literals: &[(usize, ScalarValue)],
        ) -> Result<Arc<DataType>> {
            assert_eq!(literals.len(), 1);
            let (idx, val) = &literals[0];
            assert_eq!(idx, &2);

            let take_idx = match val {
                ScalarValue::Int64(Some(v)) if v < &2 => *v as usize,
                _ => unreachable!(),
            };

            Ok(Arc::new(data_types[take_idx].clone()))
        }
    }

    let signature = Signature::exact(
        vec![DataType::Float32, DataType::Float64, DataType::Int64],
        Volatility::Immutable,
    );

    let udf = ScalarUDF {
        name: UDF_NAME.to_string(),
        signature,
        return_type: Arc::new(ReturnType {}),
        fun,
    };

    ctx.register_udf(udf);

    // SELECT take(a, b, 0) AS take0, take(a, b, 1) AS take1 FROM t;
    let df = ctx.table("t").await?;
    let take = df.registry().udf(UDF_NAME)?;
    let expr0 = take
        .call(vec![col("a"), col("b"), lit(0_i64)])
        .alias("take0");
    let expr1 = take
        .call(vec![col("a"), col("b"), lit(1_i64)])
        .alias("take1");

    let df = df.select(vec![expr0, expr1])?;
    let schema = df.schema();

    // Check the output schema: each call's return type follows the literal index.
    assert_eq!(schema.field(0).data_type(), &DataType::Float32);
    assert_eq!(schema.field(1).data_type(), &DataType::Float64);

    df.show().await?;

    Ok(())
}
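Assuming the usual layout of the `datafusion-examples` crate, the new example should run with `cargo run --example complex_udf`; with the data above, `take0` echoes column `a` (Float32) and `take1` echoes column `b` (Float64).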
14 changes: 13 additions & 1 deletion datafusion/expr/src/expr_schema.rs
@@ -87,7 +87,19 @@ impl ExprSchemable for Expr {
                     .iter()
                     .map(|e| e.get_type(schema))
                     .collect::<Result<Vec<_>>>()?;
-                Ok((fun.return_type)(&data_types)?.as_ref().clone())
+
+                let mut literals = vec![];
+                args.iter().enumerate().for_each(|(i, arg)| {
+                    if let Expr::Literal(scalar_value) = arg {
+                        literals.push((i, scalar_value.clone()));
+                    }
+                });
+
+                Ok(fun
+                    .return_type
+                    .infer(&data_types, &literals)?
+                    .as_ref()
+                    .clone())
             }
             Expr::ScalarFunction(ScalarFunction { fun, args }) => {
                 let arg_data_types = args
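Concretely, for the example's `take(col("a"), col("b"), lit(1))`, `get_type` collects `data_types == [Float32, Float64, Int64]` and `literals == [(2, ScalarValue::Int64(Some(1)))]`, which lets the factory resolve the return type to `Float64` at plan time.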
21 changes: 20 additions & 1 deletion datafusion/expr/src/function.rs
@@ -21,7 +21,7 @@ use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature};
 use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue};
 use arrow::datatypes::DataType;
 use datafusion_common::utils::datafusion_strsim;
-use datafusion_common::Result;
+use datafusion_common::{Result, ScalarValue};
 use std::sync::Arc;
 use strum::IntoEnumIterator;
 
@@ -36,10 +36,29 @@ use strum::IntoEnumIterator;
 pub type ScalarFunctionImplementation =
     Arc<dyn Fn(&[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>;
 
+/// Factory that returns the function's return type given the input argument types and constant arguments
+pub trait ReturnTypeFactory: Send + Sync {
+    fn infer(
+        &self,
+        data_types: &[DataType],
+        literals: &[(usize, ScalarValue)],
+    ) -> Result<Arc<DataType>>;
+}
+
 /// Factory that returns the function's return type given the input argument types
 pub type ReturnTypeFunction =
     Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>;
 
+impl ReturnTypeFactory for ReturnTypeFunction {
+    fn infer(
+        &self,
+        data_types: &[DataType],
+        _literals: &[(usize, ScalarValue)],
+    ) -> Result<Arc<DataType>> {
+        self(data_types)
+    }
+}
+
 /// Factory that returns an accumulator for the given aggregate, given
 /// its return datatype.
 pub type AccumulatorFactoryFunction =
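The blanket impl above is what keeps existing closure-based UDFs source-compatible: `ScalarUDF::new` (see the `udf.rs` diff below) wraps the old-style closure in an `Arc`, and `infer` forwards straight to the closure, ignoring the literals. A minimal sketch of that unchanged path (the `Float64` closure is hypothetical; the import paths assume the re-exports used elsewhere in this commit):

use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::function::ReturnTypeFactory;
use datafusion_expr::ReturnTypeFunction;
use std::sync::Arc;

fn demo() -> Result<()> {
    // A pre-refactor, closure-based return-type function...
    let legacy: ReturnTypeFunction =
        Arc::new(|_arg_types: &[DataType]| Ok(Arc::new(DataType::Float64)));

    // ...still satisfies the new trait via the blanket impl; the empty
    // literals slice is simply ignored.
    let inferred = legacy.infer(&[DataType::Float64], &[])?;
    assert_eq!(inferred.as_ref(), &DataType::Float64);
    Ok(())
}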
5 changes: 3 additions & 2 deletions datafusion/expr/src/udf.rs
@@ -17,6 +17,7 @@
 
 //! Udf module contains foundational types that are used to represent UDFs in DataFusion.
 
+use crate::function::ReturnTypeFactory;
 use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature};
 use std::fmt;
 use std::fmt::Debug;
@@ -31,7 +32,7 @@ pub struct ScalarUDF {
     /// signature
     pub signature: Signature,
     /// Return type
-    pub return_type: ReturnTypeFunction,
+    pub return_type: Arc<dyn ReturnTypeFactory>,
     /// actual implementation
     ///
     /// The fn param is the wrapped function but be aware that the function will
@@ -79,7 +80,7 @@ impl ScalarUDF {
         Self {
             name: name.to_owned(),
             signature: signature.clone(),
-            return_type: return_type.clone(),
+            return_type: Arc::new(return_type.clone()),
             fun: fun.clone(),
         }
     }
12 changes: 11 additions & 1 deletion datafusion/physical-expr/src/udf.rs
@@ -34,11 +34,21 @@ pub fn create_physical_expr(
         .map(|e| e.data_type(input_schema))
         .collect::<Result<Vec<_>>>()?;
 
+    let mut literals = vec![];
+    input_phy_exprs.iter().enumerate().for_each(|(i, expr)| {
+        if let Some(literal) = expr.as_any().downcast_ref::<crate::expressions::Literal>()
+        {
+            literals.push((i, literal.value().clone()));
+        }
+    });
+
     Ok(Arc::new(ScalarFunctionExpr::new(
         &fun.name,
         fun.fun.clone(),
         input_phy_exprs.to_vec(),
-        (fun.return_type)(&input_exprs_types)?.as_ref(),
+        fun.return_type
+            .infer(&input_exprs_types, &literals)?
+            .as_ref(),
         None,
     )))
 }
