Skip to content

Commit

Permalink
Apply type_union_resolution to array and values (#12753)
Browse files Browse the repository at this point in the history
* cleanup make array coercion rule

Signed-off-by: jayzhan211 <[email protected]>

* change to type union resolution

Signed-off-by: jayzhan211 <[email protected]>

* change value too

Signed-off-by: jayzhan211 <[email protected]>

* fix tpyo

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored Oct 5, 2024
1 parent cf76aba commit 8aafa54
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 98 deletions.
22 changes: 6 additions & 16 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,10 +471,16 @@ fn type_union_resolution_coercion(
let new_value_type = type_union_resolution_coercion(value_type, other_type);
new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t)))
}
(DataType::List(lhs), DataType::List(rhs)) => {
let new_item_type =
type_union_resolution_coercion(lhs.data_type(), rhs.data_type());
new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true))))
}
_ => {
// numeric coercion is the same as comparison coercion, both find the narrowest type
// that can accommodate both types
binary_numeric_coercion(lhs_type, rhs_type)
.or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type))
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| numeric_string_coercion(lhs_type, rhs_type))
}
Expand Down Expand Up @@ -507,22 +513,6 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
.or_else(|| struct_coercion(lhs_type, rhs_type))
}

/// Coerce `lhs_type` and `rhs_type` to a common type for `VALUES` expression
///
/// For example `VALUES (1, 2), (3.0, 4.0)` where the first row is `Int32` and
/// the second row is `Float64` will coerce to `Float64`
///
pub fn values_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
if lhs_type == rhs_type {
// same type => equality is possible
return Some(lhs_type.clone());
}
binary_numeric_coercion(lhs_type, rhs_type)
.or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type))
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| binary_coercion(lhs_type, rhs_type))
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
/// where one is numeric and one is `Utf8`/`LargeUtf8`.
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
Expand Down
5 changes: 3 additions & 2 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ use crate::logical_plan::{
Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values,
Window,
};
use crate::type_coercion::binary::values_coercion;
use crate::utils::{
can_hash, columnize_expr, compare_sort_expr, expr_to_columns,
find_valid_equijoin_key_pair, group_window_expr_by_sort_keys,
Expand All @@ -53,6 +52,7 @@ use datafusion_common::{
plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
TableReference, ToDFSchema, UnnestOptions,
};
use datafusion_expr_common::type_coercion::binary::type_union_resolution;

use super::dml::InsertOp;
use super::plan::{ColumnUnnestList, ColumnUnnestType};
Expand Down Expand Up @@ -209,7 +209,8 @@ impl LogicalPlanBuilder {
}
if let Some(prev_type) = common_type {
// get common type of each column values.
let Some(new_type) = values_coercion(&data_type, &prev_type) else {
let data_types = vec![prev_type.clone(), data_type.clone()];
let Some(new_type) = type_union_resolution(&data_types) else {
return plan_err!("Inconsistent data type across values list at row {i} column {j}. Was {prev_type} but found {data_type}");
};
common_type = Some(new_type);
Expand Down
23 changes: 15 additions & 8 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,20 @@ pub fn data_types(
try_coerce_types(valid_types, current_types, &signature.type_signature)
}

fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
if let TypeSignature::OneOf(signatures) = type_signature {
return signatures.iter().all(is_well_supported_signature);
}

matches!(
type_signature,
TypeSignature::UserDefined
| TypeSignature::Numeric(_)
| TypeSignature::Coercible(_)
| TypeSignature::Any(_)
)
}

fn try_coerce_types(
valid_types: Vec<Vec<DataType>>,
current_types: &[DataType],
Expand All @@ -175,14 +189,7 @@ fn try_coerce_types(
let mut valid_types = valid_types;

// Well-supported signature that returns exact valid types.
if !valid_types.is_empty()
&& matches!(
type_signature,
TypeSignature::UserDefined
| TypeSignature::Numeric(_)
| TypeSignature::Coercible(_)
)
{
if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
// exact valid types
assert_eq!(valid_types.len(), 1);
let valid_types = valid_types.swap_remove(0);
Expand Down
54 changes: 23 additions & 31 deletions datafusion/functions-nested/src/make_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! [`ScalarUDFImpl`] definitions for `make_array` function.
use std::vec;
use std::{any::Any, sync::Arc};

use arrow::array::{ArrayData, Capacities, MutableArrayData};
Expand All @@ -26,9 +27,8 @@ use arrow_array::{
use arrow_buffer::OffsetBuffer;
use arrow_schema::DataType::{LargeList, List, Null};
use arrow_schema::{DataType, Field};
use datafusion_common::internal_err;
use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result};
use datafusion_expr::type_coercion::binary::comparison_coercion;
use datafusion_expr::binary::type_union_resolution;
use datafusion_expr::TypeSignature;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

Expand Down Expand Up @@ -82,19 +82,12 @@ impl ScalarUDFImpl for MakeArray {
match arg_types.len() {
0 => Ok(empty_array_type()),
_ => {
let mut expr_type = DataType::Null;
for arg_type in arg_types {
if !arg_type.equals_datatype(&DataType::Null) {
expr_type = arg_type.clone();
break;
}
}

if expr_type.is_null() {
expr_type = DataType::Int64;
}

Ok(List(Arc::new(Field::new("item", expr_type, true))))
// At this point, all the type in array should be coerced to the same one
Ok(List(Arc::new(Field::new(
"item",
arg_types[0].to_owned(),
true,
))))
}
}
}
Expand All @@ -112,22 +105,21 @@ impl ScalarUDFImpl for MakeArray {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let new_type = arg_types.iter().skip(1).try_fold(
arg_types.first().unwrap().clone(),
|acc, x| {
// The coerced types found by `comparison_coercion` are not guaranteed to be
// coercible for the arguments. `comparison_coercion` returns more loose
// types that can be coerced to both `acc` and `x` for comparison purpose.
// See `maybe_data_types` for the actual coercion.
let coerced_type = comparison_coercion(&acc, x);
if let Some(coerced_type) = coerced_type {
Ok(coerced_type)
} else {
internal_err!("Coercion from {acc:?} to {x:?} failed.")
}
},
)?;
Ok(vec![new_type; arg_types.len()])
if let Some(new_type) = type_union_resolution(arg_types) {
if let DataType::FixedSizeList(field, _) = new_type {
Ok(vec![DataType::List(field); arg_types.len()])
} else if new_type.is_null() {
Ok(vec![DataType::Int64; arg_types.len()])
} else {
Ok(vec![new_type; arg_types.len()])
}
} else {
plan_err!(
"Fail to find the valid type between {:?} for {}",
arg_types,
self.name()
)
}
}
}

Expand Down
25 changes: 0 additions & 25 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,6 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
self.schema,
&func,
)?;
let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &func)?;
Ok(Transformed::yes(Expr::ScalarFunction(
ScalarFunction::new_udf(func, new_expr),
)))
Expand Down Expand Up @@ -756,30 +755,6 @@ fn coerce_arguments_for_signature_with_aggregate_udf(
.collect()
}

fn coerce_arguments_for_fun(
expressions: Vec<Expr>,
schema: &DFSchema,
fun: &Arc<ScalarUDF>,
) -> Result<Vec<Expr>> {
// Cast Fixedsizelist to List for array functions
if fun.name() == "make_array" {
expressions
.into_iter()
.map(|expr| {
let data_type = expr.get_type(schema).unwrap();
if let DataType::FixedSizeList(field, _) = data_type {
let to_type = DataType::List(Arc::clone(&field));
expr.cast_to(&to_type, schema)
} else {
Ok(expr)
}
})
.collect()
} else {
Ok(expressions)
}
}

fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
// Given expressions like:
//
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -6595,7 +6595,7 @@ select make_array(1, 2.0, null, 3)
query ?
select make_array(1.0, '2', null)
----
[1.0, 2, ]
[1.0, 2.0, ]

### FixedSizeListArray

Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/errors.slt
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,5 @@ from aggregate_test_100
order by c9


statement error Inconsistent data type across values list at row 1 column 0. Was Int64 but found Utf8
query error DataFusion error: Arrow error: Cast error: Cannot cast string 'foo' to value of Int64 type
create table foo as values (1), ('foo');
32 changes: 20 additions & 12 deletions datafusion/sqllogictest/test_files/map.slt
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,17 @@ SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']);
{[1, 2]: [a, b], [3, 4]: [b]}

query ?
SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30);
SELECT MAKE_MAP('POST', 41, 'HEAD', 53, 'PATCH', 30);
----
{POST: 41, HEAD: ab, PATCH: 30}
{POST: 41, HEAD: 53, PATCH: 30}

query error DataFusion error: Arrow error: Cast error: Cannot cast string 'ab' to value of Int64 type
SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30);

# Map keys can not be NULL
query error
SELECT MAKE_MAP('POST', 41, 'HEAD', 33, null, 30);

query ?
SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30);
----
{POST: 41, HEAD: ab, PATCH: 30}

query ?
SELECT MAKE_MAP()
----
Expand Down Expand Up @@ -517,9 +516,12 @@ query error
SELECT MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} }[NULL];

query ?
SELECT MAP { 'a': 1, 2: 3 };
SELECT MAP { 'a': 1, 'b': 3 };
----
{a: 1, 2: 3}
{a: 1, b: 3}

query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type
SELECT MAP { 'a': 1, 2: 3 };

# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key
# query ?
Expand Down Expand Up @@ -610,9 +612,12 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7)
# Tests for map_keys

query ?
SELECT map_keys(MAP { 'a': 1, 2: 3 });
SELECT map_keys(MAP { 'a': 1, 'b': 3 });
----
[a, 2]
[a, b]

query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type
SELECT map_keys(MAP { 'a': 1, 2: 3 });

query ?
SELECT map_keys(MAP {'a':1, 'b':2, 'c':3 }) FROM t;
Expand Down Expand Up @@ -657,8 +662,11 @@ SELECT map_keys(column1) from map_array_table_1;

# Tests for map_values

query ?
query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type
SELECT map_values(MAP { 'a': 1, 2: 3 });

query ?
SELECT map_values(MAP { 'a': 1, 'b': 3 });
----
[1, 3]

Expand Down
10 changes: 8 additions & 2 deletions datafusion/sqllogictest/test_files/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -348,17 +348,23 @@ VALUES (1),()
statement error DataFusion error: Error during planning: Inconsistent data length across values list: got 2 values in row 1 but expected 1
VALUES (1),(1,2)

statement error DataFusion error: Error during planning: Inconsistent data type across values list at row 1 column 0
query I
VALUES (1),('2')
----
1
2

query R
VALUES (1),(2.0)
----
1
2

statement error DataFusion error: Error during planning: Inconsistent data type across values list at row 1 column 1
query II
VALUES (1,2), (1,'2')
----
1 2
1 2

query IT
VALUES (1,'a'),(NULL,'b'),(3,'c')
Expand Down

0 comments on commit 8aafa54

Please sign in to comment.