From b74a0621bb1e10cc52ecfd2d0fa051e0045531f5 Mon Sep 17 00:00:00 2001 From: Jagdish Parihar Date: Wed, 8 Jan 2025 21:01:16 +0530 Subject: [PATCH] for 'array_repeat' if the count value is 0, return NULL instead of empty array --- datafusion/functions-nested/src/repeat.rs | 22 +++++++++++++++++--- datafusion/sqllogictest/test_files/array.slt | 6 +++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/datafusion/functions-nested/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs index da0aa5f12fde..6974499c54b0 100644 --- a/datafusion/functions-nested/src/repeat.rs +++ b/datafusion/functions-nested/src/repeat.rs @@ -24,7 +24,7 @@ use arrow_array::{ new_null_array, Array, ArrayRef, GenericListArray, Int64Array, ListArray, OffsetSizeTrait, }; -use arrow_buffer::OffsetBuffer; +use arrow_buffer::{BooleanBufferBuilder, OffsetBuffer}; use arrow_schema::DataType::{LargeList, List}; use arrow_schema::{DataType, Field}; use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; @@ -169,6 +169,7 @@ fn general_repeat( ) -> Result<ArrayRef> { let data_type = array.data_type(); let mut new_values = vec![]; + let mut null_bits = BooleanBufferBuilder::new(array.len()); let count_vec = count_array .values() @@ -178,6 +179,13 @@ fn general_repeat( .collect::<Vec<_>>(); for (row_index, &count) in count_vec.iter().enumerate() { + if count == 0 { + null_bits.append(false); + new_values.push(new_null_array(data_type, 0)); + continue; + } + + null_bits.append(true); let repeated_array = if array.is_null(row_index) { new_null_array(data_type, count) } else { @@ -203,7 +211,7 @@ fn general_repeat( Arc::new(Field::new_list_field(data_type.to_owned(), true)), OffsetBuffer::from_lengths(count_vec), values, - None, + Some(null_bits.finish().into()), )?)) } @@ -224,6 +232,7 @@ fn general_list_repeat( let data_type = list_array.data_type(); let value_type = list_array.value_type(); let mut new_values = vec![]; + let mut null_bits = 
BooleanBufferBuilder::new(list_array.len()); let count_vec = count_array .values() @@ -233,6 +242,13 @@ fn general_list_repeat( .collect::<Vec<_>>(); for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { + if count == 0 { + null_bits.append(false); + new_values.push(new_null_array(data_type, 0)); + continue; + } + + null_bits.append(true); let list_arr = match list_array_row { Some(list_array_row) => { let original_data = list_array_row.to_data(); @@ -271,6 +287,6 @@ fn general_list_repeat( Arc::new(Field::new_list_field(data_type.to_owned(), true)), OffsetBuffer::<O>::from_lengths(lengths), values, - None, + Some(null_bits.finish().into()), )?)) } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 90003b28572a..e943db532b3e 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2729,7 +2729,7 @@ select list_repeat('rust', 4), list_repeat(null, 0); ---- -[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] [, ] [-1, -1, -1, -1, -1] [] [rust, rust, rust, rust] [] +[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] [, ] [-1, -1, -1, -1, -1] NULL [rust, rust, rust, rust] NULL # array_repeat scalar function #2 (element as list) query ???? @@ -2783,7 +2783,7 @@ from array_repeat_table; [1] [1.1] [a] [[4, 5, 6]] [1, 1, 1] [[1]] [, ] [, ] [, ] [, ] [, , ] [[1], [1]] [2, 2, 2] [2.2, 2.2, 2.2] [rust, rust, rust] [[7], [7], [7]] [2, 2, 2] [[1], [1], [1]] -[] [] [] [] [3, 3, 3] [] +NULL NULL NULL NULL [3, 3, 3] NULL query ?????? select @@ -2798,7 +2798,7 @@ from large_array_repeat_table; [1] [1.1] [a] [[4, 5, 6]] [1, 1, 1] [[1]] [, ] [, ] [, ] [, ] [, , ] [[1], [1]] [2, 2, 2] [2.2, 2.2, 2.2] [rust, rust, rust] [[7], [7], [7]] [2, 2, 2] [[1], [1], [1]] -[] [] [] [] [3, 3, 3] [] +NULL NULL NULL NULL [3, 3, 3] NULL statement ok drop table array_repeat_table;