Merge remote-tracking branch 'apache/main' into alamb/PR_flow
alamb committed May 17, 2024
2 parents 9e95d64 + 32b63ff commit 9002dea
Showing 27 changed files with 1,279 additions and 1,086 deletions.
15 changes: 5 additions & 10 deletions datafusion-examples/examples/advanced_udaf.rs
@@ -31,8 +31,8 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl,
GroupsAccumulator, Signature,
function::{AccumulatorArgs, StateFieldsArgs},
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the full AggregateUDFImpl API to implement a user
@@ -92,21 +92,16 @@ impl AggregateUDFImpl for GeoMeanUdaf {
}

/// This is the description of the state. The accumulator's `state()` must match the types here.
fn state_fields(
&self,
_name: &str,
value_type: DataType,
_ordering_fields: Vec<arrow_schema::Field>,
) -> Result<Vec<arrow_schema::Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<arrow_schema::Field>> {
Ok(vec![
Field::new("prod", value_type, true),
Field::new("prod", args.return_type.clone(), true),
Field::new("n", DataType::UInt32, true),
])
}

/// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator`
/// which is used for cases when there are grouping columns in the query
fn groups_accumulator_supported(&self) -> bool {
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}

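The new `AccumulatorArgs` parameter is unused in this example, but it lets an implementation opt in selectively rather than unconditionally. A hedged sketch of that pattern — a hypothetical method body, not code from this commit:

```rust
// Inside an `impl AggregateUDFImpl for ... { ... }` block:
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
    // Hypothetical policy: use the faster GroupsAccumulator only for
    // non-DISTINCT calls; `args.is_distinct` comes from the new struct.
    !args.is_distinct
}
```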
11 changes: 3 additions & 8 deletions datafusion-examples/examples/simplify_udaf_expression.rs
@@ -17,7 +17,7 @@

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_expr::function::AggregateFunctionSimplification;
use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
use datafusion_expr::simplify::SimplifyInfo;

use std::{any::Any, sync::Arc};
@@ -70,16 +70,11 @@ impl AggregateUDFImpl for BetterAvgUdaf {
unimplemented!("should not be invoked")
}

fn state_fields(
&self,
_name: &str,
_value_type: DataType,
_ordering_fields: Vec<arrow_schema::Field>,
) -> Result<Vec<arrow_schema::Field>> {
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<arrow_schema::Field>> {
unimplemented!("should not be invoked")
}

fn groups_accumulator_supported(&self) -> bool {
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}

10 changes: 5 additions & 5 deletions datafusion/core/src/physical_optimizer/optimizer.rs
@@ -112,11 +112,6 @@ impl PhysicalOptimizer {
// Remove the ancillary output requirement operator since we are done with the planning
// phase.
Arc::new(OutputRequirements::new_remove_mode()),
// The PipelineChecker rule will reject non-runnable query plans that use
// pipeline-breaking operators on infinite input(s). The rule generates a
// diagnostic error message when this happens. It makes no changes to the
// given query plan; i.e. it only acts as a final gatekeeping rule.
Arc::new(PipelineChecker::new()),
// The aggregation limiter will try to find situations where the accumulator count
// is not tied to the cardinality, i.e. when the output of the aggregation is passed
// into an `order by max(x) limit y`. In this case it will copy the limit value down
@@ -129,6 +124,11 @@
// are not present, the load of executors such as join or union will be
// reduced by narrowing their input tables.
Arc::new(ProjectionPushdown::new()),
// The PipelineChecker rule will reject non-runnable query plans that use
// pipeline-breaking operators on infinite input(s). The rule generates a
// diagnostic error message when this happens. It makes no changes to the
// given query plan; i.e. it only acts as a final gatekeeping rule.
Arc::new(PipelineChecker::new()),
];

Self::with_rules(rules)
34 changes: 3 additions & 31 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -22,11 +22,10 @@ use arrow::compute::{concat_batches, SortOptions};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use arrow_schema::{Field, Schema};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::windows::{
create_window_expr, BoundedWindowAggExec, WindowAggExec,
create_window_expr, schema_add_window_field, BoundedWindowAggExec, WindowAggExec,
};
use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted};
use datafusion::physical_plan::{collect, InputOrderMode};
@@ -40,7 +39,6 @@ use datafusion_expr::{
};
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use itertools::Itertools;
use test_utils::add_empty_batches;

use hashbrown::HashMap;
@@ -276,7 +274,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
};

let extended_schema =
schema_add_window_fields(&args, &schema, &window_fn, fn_name)?;
schema_add_window_field(&args, &schema, &window_fn, fn_name)?;

let window_expr = create_window_expr(
&window_fn,
@@ -683,7 +681,7 @@ async fn run_window_test(
exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _;
}

let extended_schema = schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?;
let extended_schema = schema_add_window_field(&args, &schema, &window_fn, &fn_name)?;

let usual_window_exec = Arc::new(WindowAggExec::try_new(
vec![create_window_expr(
@@ -754,32 +752,6 @@ async fn run_window_test(
Ok(())
}

// The planner has a fully updated schema before calling `create_window_expr`;
// replicate the same for this test.
fn schema_add_window_fields(
args: &[Arc<dyn PhysicalExpr>],
schema: &Arc<Schema>,
window_fn: &WindowFunctionDefinition,
fn_name: &str,
) -> Result<Arc<Schema>> {
let data_types = args
.iter()
.map(|e| e.clone().as_ref().data_type(schema))
.collect::<Result<Vec<_>>>()?;
let window_expr_return_type = window_fn.return_type(&data_types)?;
let mut window_fields = schema
.fields()
.iter()
.map(|f| f.as_ref().clone())
.collect_vec();
window_fields.extend_from_slice(&[Field::new(
fn_name,
window_expr_return_type,
true,
)]);
Ok(Arc::new(Schema::new(window_fields)))
}

/// Return randomly sized record batches with:
/// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns
/// one random int32 column x
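The test-local helper removed above is superseded by `schema_add_window_field`, now exported from `datafusion::physical_plan::windows` (see the import change at the top of this file). A hedged usage sketch, assuming the library function keeps the removed helper's signature, as the unchanged call sites suggest:

```rust
use std::sync::Arc;

use arrow_schema::Schema;
use datafusion::physical_plan::windows::schema_add_window_field;
use datafusion_common::Result;
use datafusion_expr::WindowFunctionDefinition;
use datafusion_physical_expr::PhysicalExpr;

// Same contract as the removed helper: append one nullable field named
// `fn_name`, typed by the window function's return type, to `schema`.
fn extend_for_window(
    args: &[Arc<dyn PhysicalExpr>],
    schema: &Arc<Schema>,
    window_fn: &WindowFunctionDefinition,
    fn_name: &str,
) -> Result<Arc<Schema>> {
    schema_add_window_field(args, schema, window_fn, fn_name)
}
```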
@@ -725,7 +725,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
panic!("accumulator shouldn't invoke");
}

fn groups_accumulator_supported(&self) -> bool {
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}

8 changes: 2 additions & 6 deletions datafusion/expr/src/expr_fn.rs
@@ -23,6 +23,7 @@ use crate::expr::{
};
use crate::function::{
AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
StateFieldsArgs,
};
use crate::{
aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery,
@@ -690,12 +691,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
(self.accumulator)(acc_args)
}

fn state_fields(
&self,
_name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(self.state_fields.clone())
}
}
51 changes: 36 additions & 15 deletions datafusion/expr/src/function.rs
@@ -19,7 +19,7 @@
use crate::ColumnarValue;
use crate::{Accumulator, Expr, PartitionEvaluator};
use arrow::datatypes::{DataType, Schema};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use std::sync::Arc;

@@ -41,11 +41,14 @@ pub type ReturnTypeFunction =
/// [`AccumulatorArgs`] contains information about how an aggregate
/// function was called, including the types of its arguments and any optional
/// ordering expressions.
#[derive(Debug)]
pub struct AccumulatorArgs<'a> {
/// The return type of the aggregate function.
pub data_type: &'a DataType,

/// The schema of the input arguments
pub schema: &'a Schema,

/// Whether to ignore nulls.
///
/// SQL allows the user to specify `IGNORE NULLS`, for example:
@@ -66,22 +69,40 @@ pub struct AccumulatorArgs<'a> {
///
/// If no `ORDER BY` is specified, `sort_exprs` will be empty.
pub sort_exprs: &'a [Expr],

/// Whether the aggregate function is distinct.
///
/// ```sql
/// SELECT COUNT(DISTINCT column1) FROM t;
/// ```
pub is_distinct: bool,

/// The input type of the aggregate function.
pub input_type: &'a DataType,

/// The number of arguments the aggregate function takes.
pub args_num: usize,
}

impl<'a> AccumulatorArgs<'a> {
pub fn new(
data_type: &'a DataType,
schema: &'a Schema,
ignore_nulls: bool,
sort_exprs: &'a [Expr],
) -> Self {
Self {
data_type,
schema,
ignore_nulls,
sort_exprs,
}
}
/// [`StateFieldsArgs`] contains information about the fields that an
/// aggregate function's accumulator should have. Used for [`AggregateUDFImpl::state_fields`].
///
/// [`AggregateUDFImpl::state_fields`]: crate::udaf::AggregateUDFImpl::state_fields
pub struct StateFieldsArgs<'a> {
/// The name of the aggregate function.
pub name: &'a str,

/// The input type of the aggregate function.
pub input_type: &'a DataType,

/// The return type of the aggregate function.
pub return_type: &'a DataType,

/// The ordering fields of the aggregate function.
pub ordering_fields: &'a [Field],

/// Whether the aggregate function is distinct.
pub is_distinct: bool,
}

/// Factory that returns an accumulator for the given aggregate function.
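Taken together, `StateFieldsArgs` replaces the old `(name, value_type, ordering_fields)` parameter list and additionally exposes `input_type` and `is_distinct`. A hedged sketch of a consumer — the state layout is hypothetical, and it assumes `format_state_name` is reachable at `datafusion_expr::utils::format_state_name`, as the `state_fields` documentation in this commit suggests:

```rust
use arrow::datatypes::{DataType, Field};
use datafusion_common::Result;
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::utils::format_state_name;

// Hypothetical state layout for an average-style accumulator: a running
// sum typed like the return value, a row count, and whatever ordering
// fields the planner requires appended at the end.
fn example_state_fields(args: StateFieldsArgs<'_>) -> Result<Vec<Field>> {
    let mut fields = vec![
        Field::new(
            format_state_name(args.name, "sum"),
            args.return_type.clone(),
            true,
        ),
        Field::new(format_state_name(args.name, "count"), DataType::UInt64, true),
    ];
    fields.extend(args.ordering_fields.to_vec());
    Ok(fields)
}
```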
10 changes: 10 additions & 0 deletions datafusion/expr/src/logical_plan/plan.rs
@@ -2407,6 +2407,16 @@ pub enum Distinct {
On(DistinctOn),
}

impl Distinct {
/// Return a reference to the node's input
pub fn input(&self) -> &Arc<LogicalPlan> {
match self {
Distinct::All(input) => input,
Distinct::On(DistinctOn { input, .. }) => input,
}
}
}

/// Removes duplicate rows from the input
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct DistinctOn {
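A hedged sketch of the call-site simplification the new accessor enables (a hypothetical helper, not code from this commit):

```rust
use datafusion_expr::logical_plan::LogicalPlan;

// Previously callers matched both `Distinct::All` and `Distinct::On`
// to reach the child plan; `Distinct::input()` collapses that into one call.
fn distinct_input(plan: &LogicalPlan) -> Option<&LogicalPlan> {
    match plan {
        LogicalPlan::Distinct(distinct) => Some(distinct.input().as_ref()),
        _ => None,
    }
}
```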
57 changes: 33 additions & 24 deletions datafusion/expr/src/udaf.rs
@@ -17,7 +17,9 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions
use crate::function::{AccumulatorArgs, AggregateFunctionSimplification};
use crate::function::{
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
};
use crate::groups_accumulator::GroupsAccumulator;
use crate::utils::format_state_name;
use crate::{Accumulator, Expr};
@@ -177,18 +179,13 @@ impl AggregateUDF {
/// for more details.
///
/// This is used to support multi-phase aggregations
pub fn state_fields(
&self,
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
self.inner.state_fields(name, value_type, ordering_fields)
pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
self.inner.state_fields(args)
}

/// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
pub fn groups_accumulator_supported(&self) -> bool {
self.inner.groups_accumulator_supported()
pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
self.inner.groups_accumulator_supported(args)
}

/// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
@@ -232,7 +229,7 @@
/// # use arrow::datatypes::DataType;
/// # use datafusion_common::{DataFusionError, plan_err, Result};
/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr};
/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs};
/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
/// # use arrow::datatypes::Schema;
/// # use arrow::datatypes::Field;
/// #[derive(Debug, Clone)]
@@ -261,9 +258,9 @@
/// }
/// // This is the accumulator factory; DataFusion uses it to create new accumulators.
/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
/// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec<Field>) -> Result<Vec<Field>> {
/// fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
/// Ok(vec![
/// Field::new("value", value_type, true),
/// Field::new("value", args.return_type.clone(), true),
/// Field::new("ordering", DataType::UInt32, true)
/// ])
/// }
@@ -319,19 +316,17 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// The names of the fields must be unique within the query and thus should
/// be derived from `name`. See [`format_state_name`] for a utility function
/// to generate a unique name.
fn state_fields(
&self,
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
let value_fields = vec![Field::new(
format_state_name(name, "value"),
value_type,
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let fields = vec![Field::new(
format_state_name(args.name, "value"),
args.return_type.clone(),
true,
)];

Ok(value_fields.into_iter().chain(ordering_fields).collect())
Ok(fields
.into_iter()
.chain(args.ordering_fields.to_vec())
.collect())
}

/// If the aggregate expression has a specialized
Expand All @@ -344,7 +339,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// `Self::accumulator` for certain queries, such as when this aggregate is
/// used as a window function or when there are no GROUP BY columns in the
/// query.
fn groups_accumulator_supported(&self) -> bool {
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
false
}

@@ -389,6 +384,20 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
None
}

/// Returns the reverse expression of the aggregate function.
fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::NotSupported
}
}

pub enum ReversedUDAF {
/// The expression is the same as the original expression, like SUM, COUNT
Identical,
/// The expression does not support reverse calculation, like ArrayAgg
NotSupported,
/// The expression is different from the original expression
Reversed(Arc<dyn AggregateUDFImpl>),
}

/// AggregateUDF that adds an alias to the underlying function. It is better to
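With the new hook, an order-insensitive aggregate can tell the planner it is its own reverse, which matters when a window frame is evaluated from the opposite end. A minimal hedged sketch: `MySum` is hypothetical, every method except `reverse_expr` is a stub, and `ReversedUDAF` is imported through the `udaf` module since a root re-export is not shown in this diff:

```rust
use std::any::Any;

use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::udaf::ReversedUDAF;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};

/// Hypothetical order-insensitive aggregate; exists only to show the hook.
#[derive(Debug)]
struct MySum {
    signature: Signature,
}

impl MySum {
    fn new() -> Self {
        Self {
            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
        }
    }
}

impl AggregateUDFImpl for MySum {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "my_sum"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        unimplemented!("stub: not needed to illustrate reverse_expr")
    }

    // The new hook: summation does not depend on input order, so the
    // planner may reuse this function as its own reverse.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }
}
```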
