decode dictionary-encoded statistics_batch before evaluating
kosiew committed Jan 9, 2025
1 parent 10d78d4 commit 7c0cf6b
Showing 1 changed file with 48 additions and 12 deletions.
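
The change casts any dictionary-encoded Decimal128 statistics columns to plain Decimal128 arrays before the pruning predicate is evaluated against the statistics RecordBatch. Below is a minimal, standalone sketch of that idea using arrow's `cast` kernel; the Decimal128(4, 1) type, the array contents, and the `main` wrapper are illustrative and not taken from the commit.

// Standalone sketch: decode a dictionary-encoded Decimal128 statistics column
// with arrow's `cast` kernel before evaluating a predicate against it.
// The Decimal128(4, 1) type and the values are illustrative.
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Decimal128Array, DictionaryArray, Int32Array};
use arrow::compute::cast;
use arrow::datatypes::DataType;

fn main() -> Result<(), arrow::error::ArrowError> {
    // Two distinct Decimal128(4, 1) values, e.g. min statistics 10.5 and 23.0
    let values = Decimal128Array::from(vec![Some(105_i128), Some(230_i128)])
        .with_precision_and_scale(4, 1)?;
    // Dictionary keys referencing those values, one entry per container
    let keys = Int32Array::from(vec![0, 1, 0]);
    let dict: ArrayRef = Arc::new(DictionaryArray::try_new(keys, Arc::new(values))?);
    assert!(matches!(dict.data_type(), DataType::Dictionary(_, _)));

    // Unpack Dictionary(Int32, Decimal128(4, 1)) into a plain Decimal128(4, 1) array
    let decoded = cast(dict.as_ref(), &DataType::Decimal128(4, 1))?;
    assert_eq!(decoded.data_type(), &DataType::Decimal128(4, 1));
    Ok(())
}
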
datafusion/physical-optimizer/src/pruning.rs (48 additions, 12 deletions)
@@ -15,20 +15,15 @@
// specific language governing permissions and limitations
// under the License.

//! [`PruningPredicate`] to apply filter [`Expr`] to prune "containers"
//! based on statistics (e.g. Parquet Row Groups)
//!
//! [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html
use std::collections::HashSet;
use std::sync::Arc;

use arrow::array::AsArray;
use arrow::compute::{cast, CastOptions};
use arrow::{
    array::{new_null_array, ArrayRef, BooleanArray},
    array::{new_null_array, ArrayRef, AsArray, BooleanArray},
    datatypes::{DataType, Field, Schema, SchemaRef},
    record_batch::{RecordBatch, RecordBatchOptions},
};
use log::trace;
use std::collections::HashSet;
use std::sync::Arc;

use datafusion_common::error::{DataFusionError, Result};
use datafusion_common::tree_node::TransformedResult;
@@ -505,6 +500,21 @@ impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
}
}

/// Cast a dictionary-encoded statistics array to a plain `Decimal128` array
/// with the given precision and scale.
fn decode_dictionary_to_decimal(
    array: &ArrayRef,
    precision: u8,
    scale: i8,
) -> arrow::error::Result<ArrayRef> {
    // e.g. Decimal128(4, 1), matching whatever precision/scale the statistics carry
    let target_type = DataType::Decimal128(precision, scale);
    // CastOptions could be supplied (via cast_with_options) to control behaviour on
    // loss of precision; the default `cast` is used here.
    let casted = cast(array.as_ref(), &target_type)?;
    Ok(casted)
}
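
// Illustrative usage sketch: assuming `dict_array` is an ArrayRef whose type is
// Dictionary(Int32, Decimal128(4, 1)), the helper above unpacks it into a plain
// Decimal128(4, 1) array via arrow's `cast` kernel:
//   let decoded = decode_dictionary_to_decimal(&dict_array, 4, 1)?;
//   assert_eq!(decoded.data_type(), &DataType::Decimal128(4, 1));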

impl PruningPredicate {
/// Try to create a new instance of [`PruningPredicate`]
///
@@ -622,10 +632,36 @@ impl PruningPredicate {
// appropriate statistics columns for the min/max predicate
let statistics_batch =
build_statistics_record_batch(statistics, &self.required_columns)?;
println!("==> Statistics batch columns: {:#?}", statistics_batch);

println!("==> Statistics batch columns: {:#?}", statistics_batch);
// If any statistics columns are dictionary-encoded Decimal128, decode them
// to plain Decimal128 arrays before evaluating the predicate
let decoded_columns = statistics_batch
    .columns()
    .iter()
    .zip(statistics_batch.schema().fields())
    .map(|(arr, field)| {
        if let DataType::Dictionary(_, inner_ty) = field.data_type() {
            // only decimal values need decoding here
            if let DataType::Decimal128(precision, scale) = &**inner_ty {
                return decode_dictionary_to_decimal(arr, *precision, *scale);
            }
        }
        // fallback: leave the column as-is
        Ok(Arc::clone(arr))
    })
    .collect::<Result<Vec<ArrayRef>, arrow::error::ArrowError>>()?;

// Build a new RecordBatch with the decoded columns
let decoded_stats_batch =
    RecordBatch::try_new(statistics_batch.schema(), decoded_columns)?;
// Evaluate the pruning predicate on that record batch and append any results to the builder
let eval_result = self.predicate_expr.evaluate(&statistics_batch)?;
let eval_result = self.predicate_expr.evaluate(&decoded_stats_batch)?;
println!(
"==> Evaluating expression: {:?} => {:?}",
self.predicate_expr, eval_result
@@ -981,7 +1017,7 @@ fn build_statistics_record_batch<S: PruningStatistics>(

// cast statistics array to required data type (e.g. parquet
// provides timestamp statistics as "Int64")
let array = arrow::compute::cast(&array, data_type)?;
let array = cast(&array, data_type)?;

fields.push(stat_field.clone());
arrays.push(array);
