From cba045119827085d211cb1e307f8d7b1dab14dfa Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Tue, 4 Jul 2023 09:32:00 +0100 Subject: [PATCH] Use upstream result type --- Cargo.toml | 14 +- datafusion-cli/Cargo.lock | 33 ++-- datafusion-cli/Cargo.toml | 14 +- datafusion/expr/Cargo.toml | 1 + datafusion/expr/src/type_coercion/binary.rs | 186 ++++---------------- 5 files changed, 63 insertions(+), 185 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 40aad73c32630..cb34b1289762a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,10 +73,10 @@ panic = 'unwind' rpath = false [patch.crates-io] -arrow = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } -arrow-arith = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } -arrow-array = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } -arrow-buffer = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } -arrow-flight = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } -arrow-schema = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } -parquet = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } +arrow = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } +arrow-arith = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } +arrow-array = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } +arrow-buffer = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } +arrow-flight = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } +arrow-schema = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } +parquet = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index c1023e205bb63..2837147a1cc9a 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -87,7 +87,7 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "ahash", "arrow-arith", @@ -108,7 +108,7 @@ dependencies = [ [[package]] name = "arrow-arith" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "arrow-array", "arrow-buffer", @@ -122,7 +122,7 @@ dependencies = [ [[package]] name = "arrow-array" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "ahash", "arrow-buffer", @@ -138,7 +138,7 @@ dependencies = [ [[package]] name = "arrow-buffer" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "half", "num", @@ -147,7 +147,7 @@ dependencies = [ [[package]] name = "arrow-cast" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "arrow-array", "arrow-buffer", @@ -164,7 +164,7 @@ dependencies = [ [[package]] name = "arrow-csv" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "arrow-array", "arrow-buffer", @@ -182,7 +182,7 @@ dependencies = [ [[package]] name = "arrow-data" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "arrow-buffer", "arrow-schema", @@ -193,7 +193,7 @@ dependencies = [ [[package]] name = "arrow-ipc" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "arrow-array", "arrow-buffer", @@ -206,7 +206,7 @@ dependencies = [ [[package]] name = "arrow-json" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "arrow-array", "arrow-buffer", @@ -225,7 +225,7 @@ dependencies = [ [[package]] name = "arrow-ord" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "arrow-array", "arrow-buffer", @@ -239,7 +239,7 @@ dependencies = [ [[package]] name = "arrow-row" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "ahash", "arrow-array", @@ -253,12 +253,12 @@ dependencies = [ [[package]] name = "arrow-schema" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" [[package]] name = "arrow-select" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "arrow-array", "arrow-buffer", @@ -270,7 +270,7 @@ dependencies = [ [[package]] name = "arrow-string" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "arrow-array", "arrow-buffer", @@ -1108,6 +1108,7 @@ version = "27.0.0" dependencies = [ "ahash", "arrow", + "arrow-arith", "datafusion-common", "lazy_static", "sqlparser", @@ -2220,7 +2221,7 @@ dependencies = [ [[package]] name = "parquet" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" dependencies = [ "ahash", "arrow-array", @@ -3760,4 +3761,4 @@ dependencies = [ [[patch.unused]] name = "arrow-flight" version = "43.0.0" -source = "git+https://github.com/tustvold/arrow-rs.git?rev=41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6#41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" +source = "git+https://github.com/tustvold/arrow-rs.git?rev=f402dd3fe427453e43d7df8a5974151ca93d3306#f402dd3fe427453e43d7df8a5974151ca93d3306" diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 3444722b0f68a..2283dc9a68e39 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -51,10 +51,10 @@ predicates = "3.0" rstest = "0.17" [patch.crates-io] -arrow = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } -arrow-arith = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } -arrow-array = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } -arrow-buffer = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } -arrow-flight = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } -arrow-schema = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } -parquet = { git = "https://github.com/tustvold/arrow-rs.git", rev = "41a7e5ac8691b181199e9d2fc8b90c383c6a8cd6" } +arrow = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } +arrow-arith = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } +arrow-array = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } +arrow-buffer = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } +arrow-flight = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } +arrow-schema = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } +parquet = { git = "https://github.com/tustvold/arrow-rs.git", rev = "f402dd3fe427453e43d7df8a5974151ca93d3306" } diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index a5bcaa552fb36..0e4c331b8da52 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -37,6 +37,7 @@ path = "src/lib.rs" [dependencies] ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } arrow = { workspace = true } +arrow-arith = { workspace = true } datafusion-common = { path = "../common", version = "27.0.0" } lazy_static = { version = "^1.4.0" } sqlparser = { workspace = true } diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index c4188af0cf40d..4a6a7933ea2be 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -17,6 +17,7 @@ //! Coercion rules for matching argument types for binary operators +use arrow::array::{new_empty_array, Array}; use arrow::compute::can_cast_types; use arrow::datatypes::{ DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, @@ -118,8 +119,23 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result Operator::Multiply | Operator::Divide| Operator::Modulo => { - // TODO: this logic would be easier to follow if the functions were inlined - if let Some(ret) = mathematics_temporal_result_type(lhs, rhs) { + let get_result = |lhs, rhs| { + use arrow_arith::operation::*; + let l = new_empty_array(lhs); + let r = new_empty_array(rhs); + + let result = match op { + Operator::Plus => add_wrapping(&l, &r), + Operator::Minus => sub_wrapping(&l, &r), + Operator::Multiply => mul_wrapping(&l, &r), + Operator::Divide => div(&l, &r), + Operator::Modulo => rem(&l, &r), + _ => unreachable!(), + }; + result.map(|x| x.data_type().clone()) + }; + + if let Ok(ret) = get_result(lhs, rhs) { // Temporal arithmetic, e.g. Date32 + Interval Ok(Signature{ lhs: lhs.clone(), @@ -129,9 +145,9 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result } else if let Some(coerced) = temporal_coercion(lhs, rhs) { // Temporal arithmetic by first coercing to a common time representation // e.g. Date32 - Timestamp - let ret = mathematics_temporal_result_type(&coerced, &coerced).ok_or_else(|| { + let ret = get_result(&coerced, &coerced).map_err(|e| { DataFusionError::Plan(format!( - "Cannot get result type for temporal operation {coerced} {op} {coerced}" + "Cannot get result type for temporal operation {coerced} {op} {coerced}: {e}" )) })?; Ok(Signature{ @@ -141,9 +157,9 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result }) } else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) { // Decimal arithmetic, e.g. Decimal(10, 2) + Decimal(10, 0) - let ret = decimal_op_mathematics_type(op, &lhs, &rhs).ok_or_else(|| { + let ret = get_result(&lhs, &rhs).map_err(|e| { DataFusionError::Plan(format!( - "Cannot get result type for decimal operation {lhs} {op} {rhs}" + "Cannot get result type for decimal operation {lhs} {op} {rhs}: {e}" )) })?; Ok(Signature{ @@ -163,43 +179,6 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result } } -/// Returns the result type of applying mathematics operations such as -/// `+` to arguments of `lhs_type` and `rhs_type`. -fn mathematics_temporal_result_type( - lhs_type: &DataType, - rhs_type: &DataType, -) -> Option { - use arrow::datatypes::DataType::*; - use arrow::datatypes::IntervalUnit::*; - use arrow::datatypes::TimeUnit::*; - - match (lhs_type, rhs_type) { - // datetime +/- interval - (Interval(_), Timestamp(_, _)) => Some(rhs_type.clone()), - (Timestamp(_, _), Interval(_)) => Some(lhs_type.clone()), - (Interval(_), Date32) => Some(rhs_type.clone()), - (Date32, Interval(_)) => Some(lhs_type.clone()), - (Interval(_), Date64) => Some(rhs_type.clone()), - (Date64, Interval(_)) => Some(lhs_type.clone()), - // interval +/- - (Interval(l), Interval(h)) if l == h => Some(lhs_type.clone()), - (Interval(_), Interval(_)) => Some(Interval(MonthDayNano)), - // timestamp - timestamp - (Timestamp(Second, _), Timestamp(Second, _)) - | (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => { - Some(Interval(DayTime)) - } - (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) - | (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => { - Some(Interval(MonthDayNano)) - } - // date - date - (Date32, Date32) => Some(Interval(DayTime)), - (Date64, Date64) => Some(Interval(MonthDayNano)), - _ => None, - } -} - /// returns the resulting type of a binary expression evaluating the `op` with the left and right hand types pub fn get_result_type( lhs: &DataType, @@ -517,78 +496,6 @@ fn create_decimal_type(precision: u8, scale: i8) -> DataType { ) } -/// Returns the output type of applying mathematics operations on two decimal types. -/// The rule is from spark. Note that this is different to the coerced type applied -/// to two sides of the arithmetic operation. -pub fn decimal_op_mathematics_type( - mathematics_op: &Operator, - left_decimal_type: &DataType, - right_decimal_type: &DataType, -) -> Option { - use arrow::datatypes::DataType::*; - match (left_decimal_type, right_decimal_type) { - // The coercion rule from spark - // https://github.com/apache/spark/blob/c20af535803a7250fef047c2bf0fe30be242369d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala#L35 - (Decimal128(p1, s1), Decimal128(p2, s2)) => { - match mathematics_op { - Operator::Plus | Operator::Minus => { - // max(s1, s2) - let result_scale = *s1.max(s2); - // max(s1, s2) + max(p1-s1, p2-s2) + 1 - let result_precision = - result_scale + (*p1 as i8 - *s1).max(*p2 as i8 - *s2) + 1; - Some(create_decimal_type(result_precision as u8, result_scale)) - } - Operator::Multiply => { - // s1 + s2 - let result_scale = *s1 + *s2; - // p1 + p2 + 1 - let result_precision = *p1 + *p2 + 1; - Some(create_decimal_type(result_precision, result_scale)) - } - Operator::Divide => { - // max(6, s1 + p2 + 1) - let result_scale = 6.max(*s1 + *p2 as i8 + 1); - // p1 - s1 + s2 + max(6, s1 + p2 + 1) - let result_precision = result_scale + *p1 as i8 - *s1 + *s2; - Some(create_decimal_type(result_precision as u8, result_scale)) - } - Operator::Modulo => { - // max(s1, s2) - let result_scale = *s1.max(s2); - // min(p1-s1, p2-s2) + max(s1, s2) - let result_precision = - result_scale + (*p1 as i8 - *s1).min(*p2 as i8 - *s2); - Some(create_decimal_type(result_precision as u8, result_scale)) - } - _ => None, - } - } - (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => { - decimal_op_mathematics_type( - mathematics_op, - lhs_value_type.as_ref(), - rhs_value_type.as_ref(), - ) - } - (Dictionary(key_type, value_type), _) => { - let value_type = decimal_op_mathematics_type( - mathematics_op, - value_type.as_ref(), - right_decimal_type, - ); - value_type - .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type))) - } - (_, Dictionary(_, value_type)) => decimal_op_mathematics_type( - mathematics_op, - left_decimal_type, - value_type.as_ref(), - ), - _ => None, - } -} - /// Determine if at least of one of lhs and rhs is numeric, and the other must be NULL or numeric fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool { use arrow::datatypes::DataType::*; @@ -788,8 +695,8 @@ fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { mod tests { use arrow::datatypes::DataType; - use datafusion_common::DataFusionError; use datafusion_common::Result; + use datafusion_common::{assert_contains, DataFusionError}; use crate::Operator; @@ -887,21 +794,6 @@ mod tests { coerce_numeric_type_to_decimal(&DataType::Float64).unwrap(), DataType::Decimal128(30, 15) ); - - let left_decimal_type = DataType::Decimal128(10, 3); - let right_decimal_type = DataType::Decimal128(20, 4); - let op = Operator::Multiply; - let result = - decimal_op_mathematics_type(&op, &left_decimal_type, &right_decimal_type); - assert_eq!(DataType::Decimal128(31, 7), result.unwrap()); - let op = Operator::Divide; - let result = - decimal_op_mathematics_type(&op, &left_decimal_type, &right_decimal_type); - assert_eq!(DataType::Decimal128(35, 24), result.unwrap()); - let op = Operator::Modulo; - let result = - decimal_op_mathematics_type(&op, &left_decimal_type, &right_decimal_type); - assert_eq!(DataType::Decimal128(11, 4), result.unwrap()); } #[test] @@ -961,11 +853,14 @@ mod tests { assert_eq!(lhs.to_string(), "Timestamp(Millisecond, None)"); assert_eq!(rhs.to_string(), "Timestamp(Millisecond, None)"); - let (lhs, rhs) = - get_input_types(&DataType::Date32, &Operator::Plus, &DataType::Date64) - .unwrap(); - assert_eq!(lhs.to_string(), "Date64"); - assert_eq!(rhs.to_string(), "Date64"); + let err = get_input_types(&DataType::Date32, &Operator::Plus, &DataType::Date64) + .unwrap_err() + .to_string(); + + assert_contains!( + &err, + "Cannot get result type for temporal operation Date64 + Date64" + ); Ok(()) } @@ -1164,20 +1059,13 @@ mod tests { fn test_math_decimal_coercion_rule( lhs_type: DataType, rhs_type: DataType, - mathematics_op: Operator, expected_lhs_type: DataType, expected_rhs_type: DataType, - expected_output_type: DataType, ) { // The coerced types for lhs and rhs, if any of them is not decimal let (lhs_type, rhs_type) = math_decimal_coercion(&lhs_type, &rhs_type).unwrap(); assert_eq!(lhs_type, expected_lhs_type); assert_eq!(rhs_type, expected_rhs_type); - - // The output type of decimal math expression - let output_type = - decimal_op_mathematics_type(&mathematics_op, &lhs_type, &rhs_type).unwrap(); - assert_eq!(output_type, expected_output_type); } #[test] @@ -1185,55 +1073,43 @@ mod tests { test_math_decimal_coercion_rule( DataType::Decimal128(10, 2), DataType::Decimal128(10, 2), - Operator::Plus, DataType::Decimal128(10, 2), DataType::Decimal128(10, 2), - DataType::Decimal128(11, 2), ); test_math_decimal_coercion_rule( DataType::Int32, DataType::Decimal128(10, 2), - Operator::Plus, DataType::Decimal128(10, 0), DataType::Decimal128(10, 2), - DataType::Decimal128(13, 2), ); test_math_decimal_coercion_rule( DataType::Int32, DataType::Decimal128(10, 2), - Operator::Minus, DataType::Decimal128(10, 0), DataType::Decimal128(10, 2), - DataType::Decimal128(13, 2), ); test_math_decimal_coercion_rule( DataType::Int32, DataType::Decimal128(10, 2), - Operator::Multiply, DataType::Decimal128(10, 0), DataType::Decimal128(10, 2), - DataType::Decimal128(21, 2), ); test_math_decimal_coercion_rule( DataType::Int32, DataType::Decimal128(10, 2), - Operator::Divide, DataType::Decimal128(10, 0), DataType::Decimal128(10, 2), - DataType::Decimal128(23, 11), ); test_math_decimal_coercion_rule( DataType::Int32, DataType::Decimal128(10, 2), - Operator::Modulo, DataType::Decimal128(10, 0), DataType::Decimal128(10, 2), - DataType::Decimal128(10, 2), ); Ok(())