Skip to content

Commit

Permalink
Refactor spill handling in GroupedHashAggregateStream to use partial …
Browse files Browse the repository at this point in the history
…aggregate schema
  • Loading branch information
kosiew committed Jan 3, 2025
1 parent 9b5995f commit da2b11a
Showing 1 changed file with 46 additions and 3 deletions.
49 changes: 46 additions & 3 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr};
use crate::{RecordBatchStream, SendableRecordBatchStream};

use arrow::array::*;
use arrow::datatypes::SchemaRef;
use arrow::datatypes::{Schema, SchemaRef};
use arrow_schema::SortOptions;
use datafusion_common::{internal_err, DataFusionError, Result};
use datafusion_execution::disk_manager::RefCountedTempFile;
Expand Down Expand Up @@ -490,6 +490,11 @@ impl GroupedHashAggregateStream {
.collect::<Result<_>>()?;

let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?;

// Build partial aggregate schema for spills
let partial_agg_schema =
build_partial_agg_schema(&group_schema, &aggregate_exprs)?;

let spill_expr = group_schema
.fields
.into_iter()
Expand Down Expand Up @@ -522,7 +527,7 @@ impl GroupedHashAggregateStream {
let spill_state = SpillState {
spills: vec![],
spill_expr,
spill_schema: Arc::clone(&agg_schema),
spill_schema: partial_agg_schema,
is_stream_merging: false,
merging_aggregate_arguments,
merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
Expand Down Expand Up @@ -802,6 +807,45 @@ impl RecordBatchStream for GroupedHashAggregateStream {
}
}

// fix https://github.com/apache/datafusion/issues/13949
/// Builds a **partial aggregation** schema by combining the group columns and
/// the accumulator state columns produced by each aggregate expression.
///
/// # Why Partial Aggregation Schema Is Needed
///
/// In a multi-stage (partial/final) aggregation strategy, each partial-aggregate
/// operator produces *intermediate* states (e.g., partial sums, counts) rather
/// than final scalar values. These extra columns do **not** exist in the original
/// input schema (which may be something like `[colA, colB, ...]`). Instead,
/// each aggregator adds its own internal state columns (e.g., `[acc_state_1, acc_state_2, ...]`).
///
/// Therefore, when we spill these intermediate states or pass them to another
/// aggregation operator, we must use a schema that includes both the group
/// columns **and** the partial-state columns. Otherwise, using the original input
/// schema to read partial states will result in a column-count mismatch error.
///
/// This helper function constructs such a schema:
/// `[group_col_1, group_col_2, ..., state_col_1, state_col_2, ...]`
/// so that partial aggregation data can be handled consistently.
///
/// # Errors
///
/// Returns an error if any aggregate expression fails to produce its state
/// fields (e.g., an unsupported intermediate-state configuration).
fn build_partial_agg_schema(
    group_schema: &SchemaRef,
    aggregate_exprs: &[Arc<AggregateFunctionExpr>],
) -> Result<SchemaRef> {
    // Start with the group-by columns; the accumulator state columns are
    // appended after them, matching the layout of spilled partial batches.
    let mut fields = group_schema.fields().iter().cloned().collect::<Vec<_>>();
    for expr in aggregate_exprs {
        // NOTE: `state_fields()` returns `Result<Vec<Field>>`. The previous
        // implementation `flat_map`ped over the `Result` itself, which silently
        // dropped any `Err` (an errored aggregator simply contributed zero
        // columns, yielding a wrong schema). Propagate the error instead.
        let state_fields = expr.state_fields()?;
        // Wrap each owned `Field` in an `Arc` to match the schema's `FieldRef`s.
        fields.extend(state_fields.into_iter().map(Arc::new));
    }
    Ok(Arc::new(Schema::new(fields)))
}

impl GroupedHashAggregateStream {
/// Perform group-by aggregation for the given [`RecordBatch`].
fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> {
Expand Down Expand Up @@ -966,7 +1010,6 @@ impl GroupedHashAggregateStream {
assert_ne!(self.mode, AggregateMode::Partial);
// Use input batch (Partial mode) schema for spilling because
// the spilled data will be merged and re-evaluated later.
self.spill_state.spill_schema = batch.schema();
self.spill()?;
self.clear_shrink(batch);
}
Expand Down

0 comments on commit da2b11a

Please sign in to comment.