diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 98787d740c20..d81a26243b67 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -36,7 +36,7 @@ use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr};
 use crate::{RecordBatchStream, SendableRecordBatchStream};
 
 use arrow::array::*;
-use arrow::datatypes::SchemaRef;
+use arrow::datatypes::{Schema, SchemaRef};
 use arrow_schema::SortOptions;
 use datafusion_common::{internal_err, DataFusionError, Result};
 use datafusion_execution::disk_manager::RefCountedTempFile;
@@ -490,6 +490,11 @@ impl GroupedHashAggregateStream {
             .collect::<Result<Vec<_>>>()?;
 
         let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?;
+
+        // Build partial aggregate schema for spills
+        let partial_agg_schema =
+            build_partial_agg_schema(&group_schema, &aggregate_exprs)?;
+
         let spill_expr = group_schema
             .fields
             .into_iter()
@@ -522,7 +527,7 @@ impl GroupedHashAggregateStream {
         let spill_state = SpillState {
             spills: vec![],
             spill_expr,
-            spill_schema: Arc::clone(&agg_schema),
+            spill_schema: partial_agg_schema,
             is_stream_merging: false,
             merging_aggregate_arguments,
             merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
@@ -802,6 +807,42 @@ impl RecordBatchStream for GroupedHashAggregateStream {
     }
 }
 
+// fix https://github.com/apache/datafusion/issues/13949
+/// Builds a **partial aggregation** schema by combining the group columns and
+/// the accumulator state columns produced by each aggregate expression.
+///
+/// # Why Partial Aggregation Schema Is Needed
+///
+/// In a multi-stage (partial/final) aggregation strategy, each partial-aggregate
+/// operator produces *intermediate* states (e.g., partial sums, counts) rather
+/// than final scalar values. These extra columns do **not** exist in the original
+/// input schema (which may be something like `[colA, colB, ...]`). Instead,
+/// each aggregator adds its own internal state columns (e.g., `[acc_state_1, acc_state_2, ...]`).
+///
+/// Therefore, when we spill these intermediate states or pass them to another
+/// aggregation operator, we must use a schema that includes both the group
+/// columns **and** the partial-state columns. Otherwise, using the original input
+/// schema to read partial states will result in a column-count mismatch error.
+///
+/// This helper function constructs such a schema:
+/// `[group_col_1, group_col_2, ..., state_col_1, state_col_2, ...]`
+/// so that partial aggregation data can be handled consistently.
+fn build_partial_agg_schema(
+    group_schema: &SchemaRef,
+    aggregate_exprs: &[Arc<AggregateFunctionExpr>],
+) -> Result<SchemaRef> {
+    // Start from the group columns, converted to `Vec<Arc<Field>>`
+    let fields = group_schema.fields().clone();
+    let mut fields = fields.iter().cloned().collect::<Vec<_>>();
+    for expr in aggregate_exprs {
+        // Propagate errors rather than silently skipping a failing aggregate
+        let state_fields = expr.state_fields()?;
+        // Wrap each state `Field` in an `Arc` and append it after the group columns
+        fields.extend(state_fields.into_iter().map(Arc::new));
+    }
+    Ok(Arc::new(Schema::new(fields)))
+}
+
 impl GroupedHashAggregateStream {
     /// Perform group-by aggregation for the given [`RecordBatch`].
     fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> {
@@ -966,7 +1007,6 @@ impl GroupedHashAggregateStream {
                 assert_ne!(self.mode, AggregateMode::Partial);
                 // Use input batch (Partial mode) schema for spilling because
                 // the spilled data will be merged and re-evaluated later.
-                self.spill_state.spill_schema = batch.schema();
                 self.spill()?;
                 self.clear_shrink(batch);
             }
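Note (not part of the patch): below is a minimal standalone sketch of the schema-combination rule that `build_partial_agg_schema` applies. The column name `colA` and the state fields are illustrative assumptions, not DataFusion's exact state-field names; the premise is only that in Partial mode an `AVG(c)` accumulator exposes intermediate state such as a count and a sum rather than the final average.

use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema, SchemaRef};

fn main() {
    // One group column, standing in for the output of `group_schema(...)`.
    let group_schema: SchemaRef = Arc::new(Schema::new(vec![Field::new(
        "colA",
        DataType::Utf8,
        true,
    )]));

    // Illustrative accumulator state columns for `AVG(c)` in Partial mode.
    let avg_state_fields = vec![
        Field::new("avg(c)[count]", DataType::UInt64, true),
        Field::new("avg(c)[sum]", DataType::Float64, true),
    ];

    // Same combination rule as the helper: group columns first,
    // then every accumulator state column.
    let mut fields: Vec<Arc<Field>> =
        group_schema.fields().iter().cloned().collect();
    fields.extend(avg_state_fields.into_iter().map(Arc::new));
    let spill_schema: SchemaRef = Arc::new(Schema::new(fields));

    // Prints a 3-column schema: [colA, avg(c)[count], avg(c)[sum]].
    // Spilled partial-aggregate batches match this schema, not the
    // 2-column final output schema [colA, avg(c)].
    println!("{spill_schema:?}");
}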