Skip to content

Commit

Permalink
feat: support grouping aggregate function
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonLi-cn committed Apr 24, 2024
1 parent 70db5ea commit c716130
Show file tree
Hide file tree
Showing 15 changed files with 366 additions and 29 deletions.
1 change: 1 addition & 0 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
group_indices: &[usize],
opt_filter: Option<&arrow::array::BooleanArray>,
total_num_groups: usize,
_grouping_set: &[bool],
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = values[0].as_primitive::<Float64Type>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,7 @@ impl GroupsAccumulator for TestGroupsAccumulator {
_group_indices: &[usize],
_opt_filter: Option<&arrow_array::BooleanArray>,
_total_num_groups: usize,
_grouping_set: &[bool],
) -> Result<()> {
Ok(())
}
Expand Down
10 changes: 6 additions & 4 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,12 @@ impl AggregateFunction {
pub fn signature(&self) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match self {
AggregateFunction::Count => Signature::variadic_any(Volatility::Immutable),
AggregateFunction::ApproxDistinct
| AggregateFunction::Grouping
| AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
AggregateFunction::Count | AggregateFunction::Grouping => {
Signature::variadic_any(Volatility::Immutable)
}
AggregateFunction::ApproxDistinct | AggregateFunction::ArrayAgg => {
Signature::any(1, Volatility::Immutable)
}
AggregateFunction::Min | AggregateFunction::Max => {
let valid = STRINGS
.iter()
Expand Down
4 changes: 4 additions & 0 deletions datafusion/expr/src/groups_accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ pub trait GroupsAccumulator: Send {
/// * `total_num_groups`: the number of groups (the largest
/// group_index is thus `total_num_groups - 1`).
///
/// * `grouping_set`: An indicator of whether the columns in the
/// GroupingSet is null, typically used in GROUPING aggregate function.
///
/// Note that subsequent calls to update_batch may have larger
/// total_num_groups as new groups are seen.
fn update_batch(
Expand All @@ -100,6 +103,7 @@ pub trait GroupsAccumulator: Send {
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
grouping_set: &[bool],
) -> Result<()>;

/// Returns the final aggregate value for each group as a single
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ pub fn coerce_types(
| AggregateFunction::FirstValue
| AggregateFunction::LastValue => Ok(input_types.to_vec()),
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
AggregateFunction::Grouping => Ok(input_types.to_vec()),
AggregateFunction::StringAgg => {
if !is_string_agg_supported_arg_type(&input_types[0]) {
return plan_err!(
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ where
group_indices: &[usize],
opt_filter: Option<&arrow_array::BooleanArray>,
total_num_groups: usize,
_grouping_set: &[bool],
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = values[0].as_primitive::<T>();
Expand Down
8 changes: 3 additions & 5 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,9 @@ pub fn create_aggregate_expr(
input_phy_exprs[0].clone(),
name,
)),
(AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new(
input_phy_exprs[0].clone(),
name,
data_type,
)),
(AggregateFunction::Grouping, _) => {
Arc::new(expressions::Grouping::new(input_phy_exprs, name, data_type))
}
(AggregateFunction::BitAnd, _) => Arc::new(expressions::BitAnd::new(
input_phy_exprs[0].clone(),
name,
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/aggregate/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ impl GroupsAccumulator for CountGroupsAccumulator {
group_indices: &[usize],
opt_filter: Option<&arrow_array::BooleanArray>,
total_num_groups: usize,
_grouping_set: &[bool],
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = &values[0];
Expand Down
207 changes: 194 additions & 13 deletions datafusion/physical-expr/src/aggregate/grouping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@
//! Defines physical expressions that can evaluated at runtime during query execution
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

use crate::aggregate::groups_accumulator::accumulate::accumulate_indices;
use crate::aggregate::utils::down_cast_any_ref;
use crate::{AggregateExpr, PhysicalExpr};
use arrow::datatypes::DataType;
use arrow::datatypes::Field;
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::Accumulator;
use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
use arrow_array::{Array, ArrayRef, Int32Array, PrimitiveArray};
use datafusion_common::{not_impl_err, DataFusionError, Result};
use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator};
use datafusion_physical_expr_common::expressions::column::Column;

use crate::expressions::format_state_name;

Expand All @@ -36,23 +42,34 @@ pub struct Grouping {
name: String,
data_type: DataType,
nullable: bool,
expr: Arc<dyn PhysicalExpr>,
exprs: Vec<Arc<dyn PhysicalExpr>>,
}

impl Grouping {
/// Create a new GROUPING aggregate function.
pub fn new(
expr: Arc<dyn PhysicalExpr>,
exprs: Vec<Arc<dyn PhysicalExpr>>,
name: impl Into<String>,
data_type: DataType,
) -> Self {
Self {
name: name.into(),
expr,
exprs,
data_type,
nullable: true,
}
}

/// Create a new GroupingGroupsAccumulator
pub fn create_grouping_groups_accumulator(
&self,
group_by_exprs: &[(Arc<dyn PhysicalExpr>, String)],
) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(GroupingGroupsAccumulator::new(
&self.exprs,
group_by_exprs,
)?))
}
}

impl AggregateExpr for Grouping {
Expand All @@ -65,6 +82,12 @@ impl AggregateExpr for Grouping {
Ok(Field::new(&self.name, DataType::Int32, self.nullable))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
not_impl_err!(
"physical plan is not yet implemented for GROUPING aggregate function"
)
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new(
format_state_name(&self.name, "grouping"),
Expand All @@ -74,13 +97,7 @@ impl AggregateExpr for Grouping {
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
not_impl_err!(
"physical plan is not yet implemented for GROUPING aggregate function"
)
self.exprs.clone()
}

fn name(&self) -> &str {
Expand All @@ -96,8 +113,172 @@ impl PartialEq<dyn Any> for Grouping {
self.name == x.name
&& self.data_type == x.data_type
&& self.nullable == x.nullable
&& self.expr.eq(&x.expr)
&& self.exprs.len() == x.exprs.len()
&& self
.exprs
.iter()
.zip(x.exprs.iter())
.all(|(expr1, expr2)| expr1.eq(expr2))
})
.unwrap_or(false)
}
}

#[derive(Debug)]
struct GroupingGroupsAccumulator {
/// Grouping columns' indices in grouping set
indices: Vec<usize>,

/// Mask per group.
///
/// Note this is an i32 and not a u32 (or usize) because the
/// output type of grouping is `DataType::Int32`. Thus by using `i32`
/// for the grouping, the output [`Int32Array`] can be created
/// without copy.
masks: Vec<i32>,
}

impl GroupingGroupsAccumulator {
pub fn new(
grouping_exprs: &[Arc<dyn PhysicalExpr>],
group_by_exprs: &[(Arc<dyn PhysicalExpr>, String)],
) -> Result<Self> {
macro_rules! downcast_column {
($EXPR:expr) => {{
if let Some(column) = $EXPR.as_any().downcast_ref::<Column>() {
column
} else {
return Err(DataFusionError::Execution(
"Grouping only supports grouping set which only contains Column Expr".to_string(),
));
}
}}
}

// collect column indices of group_by_exprs, only Column Expr
let mut group_by_column_indices = Vec::with_capacity(group_by_exprs.len());
for (group_by_expr, _) in group_by_exprs.iter() {
let column = downcast_column!(group_by_expr);
group_by_column_indices.push(column.index());
}

// collect grouping_exprs' indices in group_by_exprs list, eg:
// SQL: SELECT c1, c2, grouping(c2, c1) FROM t GROUP BY ROLLUP(c1, c2);
// group_by_exprs: [c1, c2]
// grouping_exprs: [c2, c1]
// indices: [1, 0]
let mut indices = Vec::with_capacity(grouping_exprs.len());
for grouping_expr in grouping_exprs {
let column = downcast_column!(grouping_expr);
indices.push(find_grouping_column_index(
&group_by_column_indices,
column.index(),
)?);
}

Ok(Self {
indices,
masks: vec![],
})
}
}

fn find_grouping_column_index(
group_by_column_indices: &[usize],
grouping_column_index: usize,
) -> Result<usize> {
for (i, group_by_column_index) in group_by_column_indices.iter().enumerate() {
if grouping_column_index == *group_by_column_index {
return Ok(i);
}
}
Err(DataFusionError::Execution(
"Not found grouping column in group by columns".to_string(),
))
}

fn compute_mask(indices: &[usize], grouping_set: &[bool]) -> i32 {
let mut mask = 0;
for (i, index) in indices.iter().rev().enumerate() {
if grouping_set[*index] {
mask |= 1 << i;
}
}
mask
}

impl GroupsAccumulator for GroupingGroupsAccumulator {
fn update_batch(
&mut self,
_values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&arrow_array::BooleanArray>,
total_num_groups: usize,
grouping_set: &[bool],
) -> Result<()> {
self.masks.resize(total_num_groups, 0);
accumulate_indices(group_indices, None, opt_filter, |group_index| {
self.masks[group_index] = compute_mask(&self.indices, grouping_set);
});
Ok(())
}

fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
let masks = emit_to.take_needed(&mut self.masks);

// Mask is always non null (null inputs just don't contribute to the overall values)
let nulls = None;
let array = PrimitiveArray::<Int32Type>::new(masks.into(), nulls);

Ok(Arc::new(array))
}

// return arrays for masks
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let masks = emit_to.take_needed(&mut self.masks);
let masks: PrimitiveArray<Int32Type> = Int32Array::from(masks); // zero copy, no nulls
Ok(vec![Arc::new(masks) as ArrayRef])
}

fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&arrow_array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "one argument to merge_batch");
let masks = values[0].as_primitive::<Int32Type>();

// intermediate masks are always created as non null
assert_eq!(masks.null_count(), 0);
let masks = masks.values();

self.masks.resize(total_num_groups, 0);
match opt_filter {
Some(filter) => filter
.iter()
.zip(group_indices.iter())
.zip(masks.iter())
.for_each(|((filter_value, &group_index), mask)| {
if let Some(true) = filter_value {
self.masks[group_index] = *mask;
}
}),
None => {
group_indices
.iter()
.zip(masks.iter())
.for_each(|(&group_index, mask)| {
self.masks[group_index] = *mask;
})
}
}

Ok(())
}

fn size(&self) -> usize {
self.masks.capacity() * std::mem::size_of::<usize>()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
_grouping_set: &[bool],
) -> Result<()> {
self.invoke_per_accumulator(
values,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ where
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
_grouping_set: &[bool],
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = values[0].as_boolean();
Expand Down Expand Up @@ -129,7 +130,7 @@ where
total_num_groups: usize,
) -> Result<()> {
// update / merge are the same
self.update_batch(values, group_indices, opt_filter, total_num_groups)
self.update_batch(values, group_indices, opt_filter, total_num_groups, &[])
}

fn size(&self) -> usize {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ where
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
_grouping_set: &[bool],
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = values[0].as_primitive::<T>();
Expand Down Expand Up @@ -131,7 +132,7 @@ where
total_num_groups: usize,
) -> Result<()> {
// update / merge are the same
self.update_batch(values, group_indices, opt_filter, total_num_groups)
self.update_batch(values, group_indices, opt_filter, total_num_groups, &[])
}

fn size(&self) -> usize {
Expand Down
Loading

0 comments on commit c716130

Please sign in to comment.