Skip to content

Commit

Permalink
RecordBatch normalization (flattening) (#6758)
Browse files Browse the repository at this point in the history
* Added set up for the example of flattening from pyarrow.

* Logic for recursive normalizer with a base normalize function, based on pola-rs.

* Added recursive normalize function for `Schema`, and started building iterative function for `RecordBatch`. Not sure which one is better currently.

* Built out a bit more of the iterative normalize.

* Fixed normalize function for `RecordBatch`. Adjusted test case to match the example from PyArrow.

* Added tests for `Schema` normalization. Partial tests for `RecordBatch`.

* Removed stray comments.

* Commenting out exclamation field.

* Fixed test for `RecordBatch`.

* Formatting.

* Additional documentation for `normalize` functions. Switched `Schema` normalization to iterative approach.

* Forgot to push to the columns in the else case.

* Adjusted the documentation to include the parameters.

* Formatting.

* Edited examples to not be run as tests.

* Adjusted based on some of the suggestions. Simplified the matching and if statements, simplified the VecDeque fields.

* Additional test cases for List and FixedSizeList in Schema.

* Additional test cases for deeply nested normalization.

* Suggestions from Jefffrey on the descriptions and stack initialization.

* Forgot parenthesis.

---------

Co-authored-by: nglime <[email protected]>
  • Loading branch information
ngli-me and nglime authored Jan 24, 2025
1 parent 7bb96c5 commit 001239d
Show file tree
Hide file tree
Showing 2 changed files with 768 additions and 4 deletions.
283 changes: 280 additions & 3 deletions arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
//! A two-dimensional batch of column-oriented data with a defined
//! [schema](arrow_schema::Schema).
use crate::cast::AsArray;
use crate::{new_empty_array, Array, ArrayRef, StructArray};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaBuilder, SchemaRef};
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
use std::ops::Index;
use std::sync::Arc;

Expand Down Expand Up @@ -394,6 +395,108 @@ impl RecordBatch {
)
}

/// Normalize a semi-structured [`RecordBatch`] into a flat table.
///
/// Nested [`Field`]s will generate names separated by `separator`, up to a depth of `max_level`
/// (unlimited if `None`).
///
/// e.g. given a [`RecordBatch`] with schema:
///
/// ```text
/// "foo": StructArray<"bar": Utf8>
/// ```
///
/// A separator of `"."` would generate a batch with the schema:
///
/// ```text
/// "foo.bar": Utf8
/// ```
///
/// Note that giving a depth of `Some(0)` to `max_level` is the same as passing in `None`;
/// it will be treated as unlimited.
///
/// # Example
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{ArrayRef, Int64Array, StringArray, StructArray, RecordBatch};
/// # use arrow_schema::{DataType, Field, Fields, Schema};
/// #
/// let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
/// let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
///
/// let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
/// let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
///
/// let a = Arc::new(StructArray::from(vec![
/// (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
/// (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
/// ]));
///
/// let schema = Schema::new(vec![
/// Field::new(
/// "a",
/// DataType::Struct(Fields::from(vec![animals_field, n_legs_field])),
/// false,
/// )
/// ]);
///
/// let normalized = RecordBatch::try_new(Arc::new(schema), vec![a])
/// .expect("valid conversion")
/// .normalize(".", None)
/// .expect("valid normalization");
///
/// let expected = RecordBatch::try_from_iter_with_nullable(vec![
/// ("a.animals", animals.clone(), true),
/// ("a.n_legs", n_legs.clone(), true),
/// ])
/// .expect("valid conversion");
///
/// assert_eq!(expected, normalized);
/// ```
pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
let max_level = match max_level.unwrap_or(usize::MAX) {
0 => usize::MAX,
val => val,
};
let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
.columns
.iter()
.zip(self.schema.fields())
.rev()
.map(|(c, f)| {
let name_vec: Vec<&str> = vec![f.name()];
(0, c, name_vec, f)
})
.collect();
let mut columns: Vec<ArrayRef> = Vec::new();
let mut fields: Vec<FieldRef> = Vec::new();

while let Some((depth, c, name, field_ref)) = stack.pop() {
match field_ref.data_type() {
DataType::Struct(ff) if depth < max_level => {
// Need to zip these in reverse to maintain original order
for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
let mut name = name.clone();
name.push(separator);
name.push(fff.name());
stack.push((depth + 1, cff, name, fff))
}
}
_ => {
let updated_field = Field::new(
name.concat(),
field_ref.data_type().clone(),
field_ref.is_nullable(),
);
columns.push(c.clone());
fields.push(Arc::new(updated_field));
}
}
}
RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
}

/// Returns the number of columns in the record batch.
///
/// # Example
Expand Down Expand Up @@ -768,15 +871,14 @@ where

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use super::*;
use crate::{
BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray,
};
use arrow_buffer::{Buffer, ToByteSlice};
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::Fields;
use std::collections::HashMap;

#[test]
fn create_record_batch() {
Expand Down Expand Up @@ -1197,6 +1299,181 @@ mod tests {
assert_ne!(batch1, batch2);
}

#[test]
fn normalize_simple() {
let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));

let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
let year_field = Arc::new(Field::new("year", DataType::Int64, true));

let a = Arc::new(StructArray::from(vec![
(animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
(n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
(year_field.clone(), Arc::new(year.clone()) as ArrayRef),
]));

let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));

let schema = Schema::new(vec![
Field::new(
"a",
DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
false,
),
Field::new("month", DataType::Int64, true),
]);

let normalized =
RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
.expect("valid conversion")
.normalize(".", Some(0))
.expect("valid normalization");

let expected = RecordBatch::try_from_iter_with_nullable(vec![
("a.animals", animals.clone(), true),
("a.n_legs", n_legs.clone(), true),
("a.year", year.clone(), true),
("month", month.clone(), true),
])
.expect("valid conversion");

assert_eq!(expected, normalized);

// check 0 and None have the same effect
let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
.expect("valid conversion")
.normalize(".", None)
.expect("valid normalization");

assert_eq!(expected, normalized);
}

#[test]
fn normalize_nested() {
// Initialize schema
let a = Arc::new(Field::new("a", DataType::Int64, true));
let b = Arc::new(Field::new("b", DataType::Int64, false));
let c = Arc::new(Field::new("c", DataType::Int64, true));

let one = Arc::new(Field::new(
"1",
DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
false,
));
let two = Arc::new(Field::new(
"2",
DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
true,
));

let exclamation = Arc::new(Field::new(
"!",
DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
false,
));

let schema = Schema::new(vec![exclamation.clone()]);

// Initialize fields
let a_field = Int64Array::from(vec![Some(0), Some(1)]);
let b_field = Int64Array::from(vec![Some(2), Some(3)]);
let c_field = Int64Array::from(vec![None, Some(4)]);

let one_field = StructArray::from(vec![
(a.clone(), Arc::new(a_field.clone()) as ArrayRef),
(b.clone(), Arc::new(b_field.clone()) as ArrayRef),
(c.clone(), Arc::new(c_field.clone()) as ArrayRef),
]);
let two_field = StructArray::from(vec![
(a.clone(), Arc::new(a_field.clone()) as ArrayRef),
(b.clone(), Arc::new(b_field.clone()) as ArrayRef),
(c.clone(), Arc::new(c_field.clone()) as ArrayRef),
]);

let exclamation_field = Arc::new(StructArray::from(vec![
(one.clone(), Arc::new(one_field) as ArrayRef),
(two.clone(), Arc::new(two_field) as ArrayRef),
]));

// Normalize top level
let normalized =
RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
.expect("valid conversion")
.normalize(".", Some(1))
.expect("valid normalization");

let expected = RecordBatch::try_from_iter_with_nullable(vec![
(
"!.1",
Arc::new(StructArray::from(vec![
(a.clone(), Arc::new(a_field.clone()) as ArrayRef),
(b.clone(), Arc::new(b_field.clone()) as ArrayRef),
(c.clone(), Arc::new(c_field.clone()) as ArrayRef),
])) as ArrayRef,
false,
),
(
"!.2",
Arc::new(StructArray::from(vec![
(a.clone(), Arc::new(a_field.clone()) as ArrayRef),
(b.clone(), Arc::new(b_field.clone()) as ArrayRef),
(c.clone(), Arc::new(c_field.clone()) as ArrayRef),
])) as ArrayRef,
true,
),
])
.expect("valid conversion");

assert_eq!(expected, normalized);

// Normalize all levels
let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
.expect("valid conversion")
.normalize(".", None)
.expect("valid normalization");

let expected = RecordBatch::try_from_iter_with_nullable(vec![
("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
])
.expect("valid conversion");

assert_eq!(expected, normalized);
}

#[test]
fn normalize_empty() {
let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
let year_field = Arc::new(Field::new("year", DataType::Int64, true));

let schema = Schema::new(vec![
Field::new(
"a",
DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
false,
),
Field::new("month", DataType::Int64, true),
]);

let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
.normalize(".", Some(0))
.expect("valid normalization");

let expected = RecordBatch::new_empty(Arc::new(
schema.normalize(".", Some(0)).expect("valid normalization"),
));

assert_eq!(expected, normalized);
}

#[test]
fn project() {
let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
Expand Down
Loading

0 comments on commit 001239d

Please sign in to comment.