From 572780bbca08c25a5d2f71016b915733e146df04 Mon Sep 17 00:00:00 2001 From: xiangjinwu <17769960+xiangjinwu@users.noreply.github.com> Date: Tue, 6 Jun 2023 18:03:48 +0800 Subject: [PATCH] feat(expr): support `array_position` and `array_replace` for 1d scenario (#10166) --- proto/expr.proto | 2 + src/expr/src/vector_op/array_positions.rs | 115 +++++++++++++++++++ src/expr/src/vector_op/array_replace.rs | 90 +++++++++++++++ src/expr/src/vector_op/mod.rs | 1 + src/frontend/src/binder/expr/function.rs | 2 + src/frontend/src/expr/pure.rs | 2 + src/frontend/src/expr/type_inference/func.rs | 31 +++++ src/tests/regress/data/sql/arrays.sql | 46 ++++---- 8 files changed, 266 insertions(+), 23 deletions(-) create mode 100644 src/expr/src/vector_op/array_replace.rs diff --git a/proto/expr.proto b/proto/expr.proto index 69f0fb0ef3ee9..4b6c5959fa3f5 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -179,6 +179,8 @@ message ExprNode { ARRAY_POSITIONS = 539; TRIM_ARRAY = 540; STRING_TO_ARRAY = 541; + ARRAY_POSITION = 542; + ARRAY_REPLACE = 543; // Int256 functions HEX_TO_INT256 = 560; diff --git a/src/expr/src/vector_op/array_positions.rs b/src/expr/src/vector_op/array_positions.rs index 930e0e0c739ba..fdaffcafe0426 100644 --- a/src/expr/src/vector_op/array_positions.rs +++ b/src/expr/src/vector_op/array_positions.rs @@ -19,6 +19,121 @@ use risingwave_expr_macro::function; use crate::error::ExprError; use crate::Result; +/// Returns the subscript of the first occurrence of the second argument in the array, or `NULL` if +/// it's not present. +/// +/// Examples: +/// +/// ```slt +/// query I +/// select array_position(array[1, null, 2, null], null); +/// ---- +/// 2 +/// +/// query I +/// select array_position(array[3, 4, 5], 2); +/// ---- +/// NULL +/// +/// query I +/// select array_position(null, 4); +/// ---- +/// NULL +/// +/// query I +/// select array_position(null, null); +/// ---- +/// NULL +/// +/// query I +/// select array_position('{yes}', true); +/// ---- +/// 1 +/// +/// # Like in PostgreSQL, searching `int` in multidimensional array is disallowed. +/// statement error +/// select array_position(array[array[1, 2], array[3, 4]], 1); +/// +/// # Unlike in PostgreSQL, it is okay to search `int[]` inside `int[][]`. +/// query I +/// select array_position(array[array[1, 2], array[3, 4]], array[3, 4]); +/// ---- +/// 2 +/// +/// statement error +/// select array_position(array[3, 4], true); +/// +/// query I +/// select array_position(array[3, 4], 4.0); +/// ---- +/// 2 +/// ``` +#[function("array_position(list, *) -> int32")] +fn array_position<'a, T: ScalarRef<'a>>( + array: Option>, + element: Option, +) -> Result> { + array_position_common(array, element, 0) +} + +/// Returns the subscript of the first occurrence of the second argument in the array, or `NULL` if +/// it's not present. The search begins at the third argument. +/// +/// Examples: +/// +/// ```slt +/// statement error +/// select array_position(array[1, null, 2, null], null, false); +/// +/// statement error +/// select array_position(array[1, null, 2, null], null, null::int); +/// +/// query II +/// select v, array_position(array[1, null, 2, null], null, v) from generate_series(-1, 5) as t(v); +/// ---- +/// -1 2 +/// 0 2 +/// 1 2 +/// 2 2 +/// 3 4 +/// 4 4 +/// 5 NULL +/// ``` +#[function("array_position(list, *, int32) -> int32")] +fn array_position_start<'a, T: ScalarRef<'a>>( + array: Option>, + element: Option, + start: Option, +) -> Result> { + let start = match start { + None => { + return Err(ExprError::InvalidParam { + name: "start", + reason: "initial position must not be null".into(), + }) + } + Some(start) => (start.max(1) - 1) as usize, + }; + array_position_common(array, element, start) +} + +fn array_position_common<'a, T: ScalarRef<'a>>( + array: Option>, + element: Option, + skip: usize, +) -> Result> { + let Some(left) = array else { return Ok(None) }; + if i32::try_from(left.len()).is_err() { + return Err(ExprError::CastOutOfRange("invalid array length")); + } + + Ok(left + .iter() + .skip(skip) + .position(|item| item == element.map(Into::into)) + .map(|idx| (idx + 1 + skip) as _)) +} + /// Returns an array of the subscripts of all occurrences of the second argument in the array /// given as first argument. Note the behavior is slightly different from PG. /// diff --git a/src/expr/src/vector_op/array_replace.rs b/src/expr/src/vector_op/array_replace.rs new file mode 100644 index 0000000000000..a0f302fc39f25 --- /dev/null +++ b/src/expr/src/vector_op/array_replace.rs @@ -0,0 +1,90 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use risingwave_common::array::{ListRef, ListValue}; +use risingwave_common::types::{ScalarRef, ToOwnedDatum}; +use risingwave_expr_macro::function; + +/// Replaces each array element equal to the second argument with the third argument. +/// +/// Examples: +/// +/// ```slt +/// query T +/// select array_replace(array[7, null, 8, null], null, 0.5); +/// ---- +/// {7,0.5,8,0.5} +/// +/// query T +/// select array_replace(null, 1, 5); +/// ---- +/// NULL +/// +/// query T +/// select array_replace(null, null, null); +/// ---- +/// NULL +/// +/// statement error +/// select array_replace(array[3, null, 4], true, false); +/// +/// # Replacing `int` in multidimensional array is not supported yet. (OK in PostgreSQL) +/// statement error +/// select array_replace(array[array[array[0, 1], array[2, 3]], array[array[4, 5], array[6, 7]]], 3, 9); +/// +/// # Unlike PostgreSQL, it is okay to replace `int[][]` inside `int[][][]`. +/// query T +/// select array_replace(array[array[array[0, 1], array[2, 3]], array[array[4, 5], array[6, 7]]], array[array[4, 5], array[6, 7]], array[array[2, 3], array[4, 5]]); +/// ---- +/// {{{0,1},{2,3}},{{2,3},{4,5}}} +/// +/// # Replacing `int[]` inside `int[][][]` is not supported by either PostgreSQL or RisingWave. +/// # This may or may not be supported later, whichever makes the `int` support above simpler. +/// statement error +/// select array_replace(array[array[array[0, 1], array[2, 3]], array[array[4, 5], array[6, 7]]], array[4, 5], array[8, 9]); +/// ``` +#[function("array_replace(list, boolean, boolean) -> list")] +#[function("array_replace(list, int16, int16) -> list")] +#[function("array_replace(list, int32, int32) -> list")] +#[function("array_replace(list, int64, int64) -> list")] +#[function("array_replace(list, decimal, decimal) -> list")] +#[function("array_replace(list, float32, float32) -> list")] +#[function("array_replace(list, float64, float64) -> list")] +#[function("array_replace(list, varchar, varchar) -> list")] +#[function("array_replace(list, bytea, bytea) -> list")] +#[function("array_replace(list, time, time) -> list")] +#[function("array_replace(list, interval, interval) -> list")] +#[function("array_replace(list, date, date) -> list")] +#[function("array_replace(list, timestamp, timestamp) -> list")] +#[function("array_replace(list, timestamptz, timestamptz) -> list")] +#[function("array_replace(list, list, list) -> list")] +#[function("array_replace(list, struct, struct) -> list")] +#[function("array_replace(list, jsonb, jsonb) -> list")] +#[function("array_replace(list, int256, int256) -> list")] +fn array_replace<'a, T: ScalarRef<'a>>( + arr: Option>, + elem_from: Option, + elem_to: Option, +) -> Option { + arr.map(|arr| { + ListValue::new( + arr.iter() + .map(|x| match x == elem_from.map(Into::into) { + true => elem_to.map(Into::into).to_owned_datum(), + false => x.to_owned_datum(), + }) + .collect(), + ) + }) +} diff --git a/src/expr/src/vector_op/mod.rs b/src/expr/src/vector_op/mod.rs index 8129a794449eb..4bc147cf3caec 100644 --- a/src/expr/src/vector_op/mod.rs +++ b/src/expr/src/vector_op/mod.rs @@ -19,6 +19,7 @@ pub mod array_length; pub mod array_positions; pub mod array_range_access; pub mod array_remove; +pub mod array_replace; pub mod ascii; pub mod bitwise_op; pub mod cardinality; diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index b8c1f749b2245..b46037a12553c 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -524,6 +524,8 @@ impl Binder { ("array_length", raw_call(ExprType::ArrayLength)), ("cardinality", raw_call(ExprType::Cardinality)), ("array_remove", raw_call(ExprType::ArrayRemove)), + ("array_replace", raw_call(ExprType::ArrayReplace)), + ("array_position", raw_call(ExprType::ArrayPosition)), ("array_positions", raw_call(ExprType::ArrayPositions)), ("trim_array", raw_call(ExprType::TrimArray)), // int256 diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index f3361d2f3824d..018b2522bc7e9 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -148,6 +148,8 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::Cardinality | expr_node::Type::TrimArray | expr_node::Type::ArrayRemove + | expr_node::Type::ArrayReplace + | expr_node::Type::ArrayPosition | expr_node::Type::HexToInt256 | expr_node::Type::JsonbAccessInner | expr_node::Type::JsonbAccessStr diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index dd0a4b6410d44..a75abdf680f4d 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -544,6 +544,37 @@ fn infer_type_for_special( .into()), } } + ExprType::ArrayReplace => { + ensure_arity!("array_replace", | inputs | == 3); + let common_type = align_array_and_element(0, &[1, 2], inputs); + match common_type { + Ok(casted) => Ok(Some(casted)), + Err(_) => Err(ErrorCode::BindError(format!( + "Cannot replace {} with {} in {}", + inputs[1].return_type(), + inputs[2].return_type(), + inputs[0].return_type(), + )) + .into()), + } + } + ExprType::ArrayPosition => { + ensure_arity!("array_position", 2 <= | inputs | <= 3); + if let Some(start) = inputs.get_mut(2) { + let owned = std::mem::replace(start, ExprImpl::literal_bool(false)); + *start = owned.cast_implicit(DataType::Int32)?; + } + let common_type = align_array_and_element(0, &[1], inputs); + match common_type { + Ok(_) => Ok(Some(DataType::Int32)), + Err(_) => Err(ErrorCode::BindError(format!( + "Cannot get position of {} in {}", + inputs[1].return_type(), + inputs[0].return_type() + )) + .into()), + } + } ExprType::ArrayPositions => { ensure_arity!("array_positions", | inputs | == 2); let common_type = align_array_and_element(0, &[1], inputs); diff --git a/src/tests/regress/data/sql/arrays.sql b/src/tests/regress/data/sql/arrays.sql index 0ceffdc129ab0..0e8f927769f5d 100644 --- a/src/tests/regress/data/sql/arrays.sql +++ b/src/tests/regress/data/sql/arrays.sql @@ -261,14 +261,14 @@ SELECT array_cat(ARRAY[1,2], ARRAY[3,4]) AS "{1,2,3,4}"; SELECT array_cat(ARRAY[1,2], ARRAY[[3,4],[5,6]]) AS "{{1,2},{3,4},{5,6}}"; SELECT array_cat(ARRAY[[3,4],[5,6]], ARRAY[1,2]) AS "{{3,4},{5,6},{1,2}}"; ---@ SELECT array_position(ARRAY[1,2,3,4,5], 4); ---@ SELECT array_position(ARRAY[5,3,4,2,1], 4); ---@ SELECT array_position(ARRAY[[1,2],[3,4]], 3); ---@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], 'mon'); ---@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], 'sat'); ---@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], NULL); ---@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu',NULL,'fri','sat'], NULL); ---@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu',NULL,'fri','sat'], 'sat'); +SELECT array_position(ARRAY[1,2,3,4,5], 4); +SELECT array_position(ARRAY[5,3,4,2,1], 4); +SELECT array_position(ARRAY[[1,2],[3,4]], 3); +SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], 'mon'); +SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], 'sat'); +SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], NULL); +SELECT array_position(ARRAY['sun','mon','tue','wed','thu',NULL,'fri','sat'], NULL); +SELECT array_position(ARRAY['sun','mon','tue','wed','thu',NULL,'fri','sat'], 'sat'); SELECT array_positions(NULL, 10); SELECT array_positions(NULL, NULL::int); @@ -296,15 +296,15 @@ SELECT array_positions(ARRAY[1,2,3,NULL,5,6,1,2,3,NULL,5,6], NULL); --@ SELECT array_position('[2:4]={1,2,3}'::int[], 1); --@ SELECT array_positions('[2:4]={1,2,3}'::int[], 1); ---@ ---@ SELECT ---@ array_position(ids, (1, 1)), ---@ array_positions(ids, (1, 1)) ---@ FROM ---@ (VALUES ---@ (ARRAY[(0, 0), (1, 1)]), ---@ (ARRAY[(1, 1)]) ---@ ) AS f (ids); + +SELECT + array_position(ids, (1, 1)), + array_positions(ids, (1, 1)) + FROM +(VALUES + (ARRAY[(0, 0), (1, 1)]), + (ARRAY[(1, 1)]) +) AS f (ids); -- operators --@ SELECT a FROM arrtest WHERE b = ARRAY[[[113,142],[1,147]]]; @@ -625,12 +625,12 @@ select array_remove(array['A','CC','D','C','RR'], 'RR'); select array_remove(array[1.0, 2.1, 3.3], 1); select array_remove('{{1,2,2},{1,4,3}}', 2); -- not allowed select array_remove(array['X','X','X'], 'X') = '{}'; ---@ select array_replace(array[1,2,5,4],5,3); ---@ select array_replace(array[1,2,5,4],5,NULL); ---@ select array_replace(array[1,2,NULL,4,NULL],NULL,5); ---@ select array_replace(array['A','B','DD','B'],'B','CC'); ---@ select array_replace(array[1,NULL,3],NULL,NULL); ---@ select array_replace(array['AB',NULL,'CDE'],NULL,'12'); +select array_replace(array[1,2,5,4],5,3); +select array_replace(array[1,2,5,4],5,NULL); +select array_replace(array[1,2,NULL,4,NULL],NULL,5); +select array_replace(array['A','B','DD','B'],'B','CC'); +select array_replace(array[1,NULL,3],NULL,NULL); +select array_replace(array['AB',NULL,'CDE'],NULL,'12'); -- array(select array-value ...) --@ select array(select array[i,i/2] from generate_series(1,5) i);