Skip to content

Commit

Permalink
feat: Enforce the uniqueness of map key name for the map/make_map fun…
Browse files Browse the repository at this point in the history
…ction (#12153)

* feat: Enforce the uniqueness of map key name for the map/make_map function

* chore: Update tests

* chore

* chore: Update tests for nested type

* refactor

* chore

* fix: Check unique key for the make_map function earlier

* fix: Update bench

* chore: Clean UP
  • Loading branch information
Weijun-H authored Sep 4, 2024
1 parent 6bbad7e commit 9ab2724
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 30 deletions.
7 changes: 6 additions & 1 deletion datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,10 +444,15 @@ pub fn arrays_into_list_array(
}

/// Helper function to convert a ListArray into a vector of ArrayRefs.
pub fn list_to_arrays<O: OffsetSizeTrait>(a: ArrayRef) -> Vec<ArrayRef> {
pub fn list_to_arrays<O: OffsetSizeTrait>(a: &ArrayRef) -> Vec<ArrayRef> {
a.as_list::<O>().iter().flatten().collect::<Vec<_>>()
}

/// Helper function to convert a FixedSizeListArray into a vector of ArrayRefs.
pub fn fixed_size_list_to_arrays(a: &ArrayRef) -> Vec<ArrayRef> {
a.as_fixed_size_list().iter().flatten().collect::<Vec<_>>()
}

/// Get the base type of a data type.
///
/// Example
Expand Down
20 changes: 12 additions & 8 deletions datafusion/functions-nested/benches/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow_schema::{DataType, Field};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::prelude::ThreadRng;
use rand::Rng;
use std::collections::HashSet;
use std::sync::Arc;

use datafusion_common::ScalarValue;
Expand All @@ -32,19 +33,22 @@ use datafusion_functions_nested::map::map_udf;
use datafusion_functions_nested::planner::NestedFunctionPlanner;

fn keys(rng: &mut ThreadRng) -> Vec<String> {
let mut keys = vec![];
for _ in 0..1000 {
keys.push(rng.gen_range(0..9999).to_string());
let mut keys = HashSet::with_capacity(1000);

while keys.len() < 1000 {
keys.insert(rng.gen_range(0..10000).to_string());
}
keys

keys.into_iter().collect()
}

fn values(rng: &mut ThreadRng) -> Vec<i32> {
let mut values = vec![];
for _ in 0..1000 {
values.push(rng.gen_range(0..9999));
let mut values = HashSet::with_capacity(1000);

while values.len() < 1000 {
values.insert(rng.gen_range(0..10000));
}
values
values.into_iter().collect()
}

fn criterion_benchmark(c: &mut Criterion) {
Expand Down
79 changes: 58 additions & 21 deletions datafusion/functions-nested/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
// under the License.

use std::any::Any;
use std::collections::VecDeque;
use std::collections::{HashSet, VecDeque};
use std::sync::Arc;

use arrow::array::ArrayData;
use arrow_array::{Array, ArrayRef, MapArray, OffsetSizeTrait, StructArray};
use arrow_buffer::{Buffer, ToByteSlice};
use arrow_schema::{DataType, Field, SchemaBuilder};

use datafusion_common::{exec_err, ScalarValue};
use datafusion_common::utils::{fixed_size_list_to_arrays, list_to_arrays};
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility};

Expand All @@ -51,24 +52,64 @@ fn can_evaluate_to_const(args: &[ColumnarValue]) -> bool {
.all(|arg| matches!(arg, ColumnarValue::Scalar(_)))
}

fn make_map_batch(args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
fn make_map_batch(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 2 {
return exec_err!(
"make_map requires exactly 2 arguments, got {} instead",
args.len()
);
}

let data_type = args[0].data_type();
let can_evaluate_to_const = can_evaluate_to_const(args);
let key = get_first_array_ref(&args[0])?;
let value = get_first_array_ref(&args[1])?;
make_map_batch_internal(key, value, can_evaluate_to_const, data_type)

// check the keys array is unique
let keys = get_first_array_ref(&args[0])?;
if keys.null_count() > 0 {
return exec_err!("map key cannot be null");
}
let key_array = keys.as_ref();

match &args[0] {
ColumnarValue::Array(_) => {
let row_keys = match key_array.data_type() {
DataType::List(_) => list_to_arrays::<i32>(&keys),
DataType::LargeList(_) => list_to_arrays::<i64>(&keys),
DataType::FixedSizeList(_, _) => fixed_size_list_to_arrays(&keys),
data_type => {
return exec_err!(
"Expected list, large_list or fixed_size_list, got {:?}",
data_type
);
}
};

row_keys
.iter()
.try_for_each(|key| check_unique_keys(key.as_ref()))?;
}
ColumnarValue::Scalar(_) => {
check_unique_keys(key_array)?;
}
}

let values = get_first_array_ref(&args[1])?;
make_map_batch_internal(keys, values, can_evaluate_to_const, args[0].data_type())
}

fn get_first_array_ref(
columnar_value: &ColumnarValue,
) -> datafusion_common::Result<ArrayRef> {
fn check_unique_keys(array: &dyn Array) -> Result<()> {
let mut seen_keys = HashSet::with_capacity(array.len());

for i in 0..array.len() {
let key = ScalarValue::try_from_array(array, i)?;
if seen_keys.contains(&key) {
return exec_err!("map key must be unique, duplicate key found: {}", key);
}
seen_keys.insert(key);
}
Ok(())
}

fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result<ArrayRef> {
match columnar_value {
ColumnarValue::Scalar(value) => match value {
ScalarValue::List(array) => Ok(array.value(0)),
Expand All @@ -85,11 +126,7 @@ fn make_map_batch_internal(
values: ArrayRef,
can_evaluate_to_const: bool,
data_type: DataType,
) -> datafusion_common::Result<ColumnarValue> {
if keys.null_count() > 0 {
return exec_err!("map key cannot be null");
}

) -> Result<ColumnarValue> {
if keys.len() != values.len() {
return exec_err!("map requires key and value lists to have the same length");
}
Expand Down Expand Up @@ -173,7 +210,7 @@ impl ScalarUDFImpl for MapFunc {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.len() % 2 != 0 {
return exec_err!(
"map requires an even number of arguments, got {} instead",
Expand All @@ -198,11 +235,11 @@ impl ScalarUDFImpl for MapFunc {
))
}

fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
make_map_batch(args)
}
}
fn get_element_type(data_type: &DataType) -> datafusion_common::Result<&DataType> {
fn get_element_type(data_type: &DataType) -> Result<&DataType> {
match data_type {
DataType::List(element) => Ok(element.data_type()),
DataType::LargeList(element) => Ok(element.data_type()),
Expand Down Expand Up @@ -273,12 +310,12 @@ fn get_element_type(data_type: &DataType) -> datafusion_common::Result<&DataType
fn make_map_array_internal<O: OffsetSizeTrait>(
keys: ArrayRef,
values: ArrayRef,
) -> datafusion_common::Result<ColumnarValue> {
) -> Result<ColumnarValue> {
let mut offset_buffer = vec![O::zero()];
let mut running_offset = O::zero();

let keys = datafusion_common::utils::list_to_arrays::<O>(keys);
let values = datafusion_common::utils::list_to_arrays::<O>(values);
let keys = list_to_arrays::<O>(&keys);
let values = list_to_arrays::<O>(&values);

let mut key_array_vec = vec![];
let mut value_array_vec = vec![];
Expand Down
39 changes: 39 additions & 0 deletions datafusion/sqllogictest/test_files/map.slt
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,45 @@ SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33]);
query error DataFusion error: Execution error: map key cannot be null
SELECT MAP(['POST', 'HEAD', null], [41, 33, 30]);

statement error DataFusion error: Execution error: map key cannot be null
CREATE TABLE duplicated_keys_table
AS VALUES
(MAP {1: [1, NULL, 3], NULL: [4, NULL, 6]});

# Test duplicate keys
# key is a scalar type
query error DataFusion error: Execution error: map key must be unique, duplicate key found: POST
SELECT MAP(['POST', 'HEAD', 'POST'], [41, 33, null]);

query error DataFusion error: Execution error: map key must be unique, duplicate key found: POST
SELECT MAP(make_array('POST', 'HEAD', 'POST'), make_array(41, 33, 30));

query error DataFusion error: Execution error: map key must be unique, duplicate key found: POST
SELECT make_map('POST', 41, 'HEAD', 33, 'POST', 30);

statement error DataFusion error: Execution error: map key must be unique, duplicate key found: 1
CREATE TABLE duplicated_keys_table
AS VALUES
(MAP {1: [1, NULL, 3], 1: [4, NULL, 6]});

statement ok
create table duplicate_keys_table as values
('a', 1, 'a', 10, ['k1', 'k1'], [1, 2]);

query error DataFusion error: Execution error: map key must be unique, duplicate key found: a
SELECT make_map(column1, column2, column3, column4) FROM duplicate_keys_table;

query error DataFusion error: Execution error: map key must be unique, duplicate key found: k1
SELECT map(column5, column6) FROM duplicate_keys_table;

# key is a nested type
query error DataFusion error: Execution error: map key must be unique, duplicate key found: \[1, 2\]
SELECT MAP([[1,2], [1,2], []], [41, 33, null]);

query error DataFusion error: Execution error: map key must be unique, duplicate key found: \[\{1:1\}\]
SELECT MAP([Map {1:'1'}, Map {1:'1'}, Map {2:'2'}], [41, 33, null]);


query ?
SELECT MAP(make_array('POST', 'HEAD', 'PATCH'), make_array(41, 33, 30));
----
Expand Down

0 comments on commit 9ab2724

Please sign in to comment.