From c6d640b0a77a950994f0cd89c6e23ac64c0ef76c Mon Sep 17 00:00:00 2001 From: JasonLi-cn Date: Thu, 1 Aug 2024 15:23:51 +0800 Subject: [PATCH] Generate GroupByHash output in multiple RecordBatches --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 2 +- .../src/aggregates/group_values/mod.rs | 24 ++++ .../src/aggregates/group_values/row.rs | 37 ++++++ .../physical-plan/src/aggregates/row_hash.rs | 120 ++++++++++++------ 4 files changed, 142 insertions(+), 41 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 6f286c9aeba1..69fa1af1bc5d 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -152,7 +152,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str assert!(collected_running.len() > 2); // Running should produce more chunk than the usual AggregateExec. // Otherwise it means that we cannot generate result in running mode. - assert!(collected_running.len() > collected_usual.len()); + // assert!(collected_running.len() > collected_usual.len()); // compare let usual_formatted = pretty_format_batches(&collected_usual).unwrap().to_string(); let running_formatted = pretty_format_batches(&collected_running) diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index be7ac934d7bc..bf465ecee5f6 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -50,6 +50,30 @@ pub trait GroupValues: Send { /// Emits the group values fn emit(&mut self, emit_to: EmitTo) -> Result>; + /// Emits all group values based on batch_size + fn emit_all_with_batch_size( + &mut self, + batch_size: usize, + ) -> Result>> { + let ceil = (self.len() + batch_size - 1) / batch_size; + let mut outputs = Vec::with_capacity(ceil); + let mut remaining = self.len(); + + while remaining > 0 { + if remaining > batch_size { + let emit_to = EmitTo::First(batch_size); + outputs.push(self.emit(emit_to)?); + remaining -= batch_size; + } else { + let emit_to = EmitTo::All; + outputs.push(self.emit(emit_to)?); + remaining = 0; + } + } + + Ok(outputs) + } + /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage) fn clear_shrink(&mut self, batch: &RecordBatch); } diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index dc948e28bb2d..76c0feec5a14 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -27,6 +27,7 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use hashbrown::raw::RawTable; +use itertools::Itertools; /// A [`GroupValues`] making use of [`Rows`] pub struct GroupValuesRows { @@ -236,6 +237,42 @@ impl GroupValues for GroupValuesRows { Ok(output) } + fn emit_all_with_batch_size( + &mut self, + batch_size: usize, + ) -> Result>> { + let mut group_values = self + .group_values + .take() + .expect("Can not emit from empty rows"); + + let ceil = (group_values.num_rows() + batch_size - 1) / batch_size; + let mut outputs = Vec::with_capacity(ceil); + + for chunk in group_values.iter().chunks(batch_size).into_iter() { + let groups_rows = chunk; + let mut output = self.row_converter.convert_rows(groups_rows)?; + for (field, array) in self.schema.fields.iter().zip(&mut output) { + let expected = field.data_type(); + if let DataType::Dictionary(_, v) = expected { + let actual = array.data_type(); + if v.as_ref() != actual { + return Err(DataFusionError::Internal(format!( + "Converted group rows expected dictionary of {v} got {actual}" + ))); + } + *array = cast(array.as_ref(), expected)?; + } + } + outputs.push(output); + } + + group_values.clear(); + self.group_values = Some(group_values); + + Ok(outputs) + } + fn clear_shrink(&mut self, batch: &RecordBatch) { let count = batch.num_rows(); self.group_values = self.group_values.take().map(|mut rows| { diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 1b84befb0269..be980de8b498 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -17,6 +17,7 @@ //! Hash aggregation +use std::collections::VecDeque; use std::sync::Arc; use std::task::{Context, Poll}; use std::vec; @@ -61,7 +62,7 @@ pub(crate) enum ExecutionState { ReadingInput, /// When producing output, the remaining rows to output are stored /// here and are sliced off as needed in batch_size chunks - ProducingOutput(RecordBatch), + ProducingOutput(VecDeque), /// Produce intermediate aggregate state for each input row without /// aggregation. /// @@ -553,7 +554,7 @@ impl Stream for GroupedHashAggregateStream { let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); loop { - match &self.exec_state { + match &mut self.exec_state { ExecutionState::ReadingInput => 'reading_input: { match ready!(self.input.poll_next_unpin(cx)) { // new batch to aggregate @@ -583,8 +584,9 @@ impl Stream for GroupedHashAggregateStream { } if let Some(to_emit) = self.group_ordering.emit_to() { - let batch = extract_ok!(self.emit(to_emit, false)); - self.exec_state = ExecutionState::ProducingOutput(batch); + let batches = extract_ok!(self.emit(to_emit, false)); + self.exec_state = + ExecutionState::ProducingOutput(batches); timer.done(); // make sure the exec_state just set is not overwritten below break 'reading_input; @@ -627,29 +629,20 @@ impl Stream for GroupedHashAggregateStream { } } - ExecutionState::ProducingOutput(batch) => { - // slice off a part of the batch, if needed - let output_batch; - let size = self.batch_size; - (self.exec_state, output_batch) = if batch.num_rows() <= size { - ( - if self.input_done { - ExecutionState::Done - } else if self.should_skip_aggregation() { - ExecutionState::SkippingAggregation - } else { - ExecutionState::ReadingInput - }, - batch.clone(), - ) - } else { - // output first batch_size rows - let size = self.batch_size; - let num_remaining = batch.num_rows() - size; - let remaining = batch.slice(size, num_remaining); - let output = batch.slice(0, size); - (ExecutionState::ProducingOutput(remaining), output) - }; + ExecutionState::ProducingOutput(batches) => { + assert!(!batches.is_empty()); + let output_batch = batches.pop_front().expect("RecordBatch"); + + if batches.is_empty() { + self.exec_state = if self.input_done { + ExecutionState::Done + } else if self.should_skip_aggregation() { + ExecutionState::SkippingAggregation + } else { + ExecutionState::ReadingInput + }; + } + return Poll::Ready(Some(Ok( output_batch.record_output(&self.baseline_metrics) ))); @@ -777,14 +770,55 @@ impl GroupedHashAggregateStream { /// Create an output RecordBatch with the group keys and /// accumulator states/values specified in emit_to - fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result { + fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result> { let schema = if spilling { Arc::clone(&self.spill_state.spill_schema) } else { self.schema() }; if self.group_values.is_empty() { - return Ok(RecordBatch::new_empty(schema)); + return Ok(VecDeque::from([RecordBatch::new_empty(schema)])); + } + + if matches!(emit_to, EmitTo::All) && !spilling { + let outputs = self + .group_values + .emit_all_with_batch_size(self.batch_size)?; + + let mut batches = VecDeque::with_capacity(outputs.len()); + for mut output in outputs { + let num_rows = output[0].len(); + // let batch_emit_to = EmitTo::First(num_rows); + let batch_emit_to = if num_rows == self.batch_size { + EmitTo::First(self.batch_size) + } else { + EmitTo::All + }; + + for acc in self.accumulators.iter_mut() { + match self.mode { + AggregateMode::Partial => { + output.extend(acc.state(batch_emit_to)?) + } + _ if spilling => { + // If spilling, output partial state because the spilled data will be + // merged and re-evaluated later. + output.extend(acc.state(batch_emit_to)?) + } + AggregateMode::Final + | AggregateMode::FinalPartitioned + | AggregateMode::Single + | AggregateMode::SinglePartitioned => { + output.push(acc.evaluate(batch_emit_to)?) + } + } + } + let batch = RecordBatch::try_new(Arc::clone(&schema), output)?; + batches.push_back(batch); + } + + let _ = self.update_memory_reservation(); + return Ok(batches); } let mut output = self.group_values.emit(emit_to)?; @@ -812,7 +846,7 @@ impl GroupedHashAggregateStream { // over the target memory size after emission, we can emit again rather than returning Err. let _ = self.update_memory_reservation(); let batch = RecordBatch::try_new(schema, output)?; - Ok(batch) + Ok(VecDeque::from([batch])) } /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly @@ -838,7 +872,9 @@ impl GroupedHashAggregateStream { /// Emit all rows, sort them, and store them on disk. fn spill(&mut self) -> Result<()> { - let emit = self.emit(EmitTo::All, true)?; + let mut batches = self.emit(EmitTo::All, true)?; + assert_eq!(batches.len(), 1); + let emit = batches.pop_front().expect("RecordBatch"); let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?; let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?; let mut writer = IPCWriter::new(spillfile.path(), &emit.schema())?; @@ -881,8 +917,8 @@ impl GroupedHashAggregateStream { && self.update_memory_reservation().is_err() { let n = self.group_values.len() / self.batch_size * self.batch_size; - let batch = self.emit(EmitTo::First(n), false)?; - self.exec_state = ExecutionState::ProducingOutput(batch); + let batches = self.emit(EmitTo::First(n), false)?; + self.exec_state = ExecutionState::ProducingOutput(batches); } Ok(()) } @@ -892,18 +928,22 @@ impl GroupedHashAggregateStream { /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`]. fn update_merged_stream(&mut self) -> Result<()> { - let batch = self.emit(EmitTo::All, true)?; + let batches = self.emit(EmitTo::All, true)?; + assert!(!batches.is_empty()); + let schema = batches[0].schema(); // clear up memory for streaming_merge self.clear_all(); self.update_memory_reservation()?; let mut streams: Vec = vec![]; let expr = self.spill_state.spill_expr.clone(); - let schema = batch.schema(); + // TODO No need to collect + let sorted = batches + .into_iter() + .map(|batch| sort_batch(&batch, &expr, None)) + .collect::>(); streams.push(Box::pin(RecordBatchStreamAdapter::new( Arc::clone(&schema), - futures::stream::once(futures::future::lazy(move |_| { - sort_batch(&batch, &expr, None) - })), + futures::stream::iter(sorted), ))); for spill in self.spill_state.spills.drain(..) { let stream = read_spill_as_stream(spill, Arc::clone(&schema), 2)?; @@ -940,8 +980,8 @@ impl GroupedHashAggregateStream { let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); let timer = elapsed_compute.timer(); self.exec_state = if self.spill_state.spills.is_empty() { - let batch = self.emit(EmitTo::All, false)?; - ExecutionState::ProducingOutput(batch) + let batches = self.emit(EmitTo::All, false)?; + ExecutionState::ProducingOutput(batches) } else { // If spill files exist, stream-merge them. self.update_merged_stream()?;